// fst-class.h
//
// Source path: src/include/fst/script/fst-class.h
// (blob a820c1c9b137613ee8aa06fb8e6c848483c6cf4c)

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: jpr@google.com (Jake Ratkiewicz)

#ifndef FST_SCRIPT_FST_CLASS_H_
#define FST_SCRIPT_FST_CLASS_H_

#include <string>

#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/vector-fst.h>
#include <iostream>
#include <fstream>
#include <sstream>

// Classes to support "boxing" all existing types of FST arcs in a single
// FstClass which hides the arc types. This allows clients to load
// and work with FSTs without knowing the arc type.

// These classes are only recommended for use in high-level scripting
// applications. Most users should use the lower-level templated versions
// corresponding to these classes.

namespace fst {
namespace script {

//
// Abstract base class defining the functionality that every impl
// implements and every interface class passes through. Below
// FstClassBase the class hierarchy bifurcates: FstClassImplBase serves
// as the base class for all implementations (of which FstClassImpl is
// currently the only one), and FstClass serves as the base class for
// all interfaces.
//
class FstClassBase {
 public:
  // Name of the arc type carried by the wrapped FST.
  virtual const string &ArcType() const = 0;

  // Name of the concrete FST implementation type that is wrapped.
  virtual const string &FstType() const = 0;

  // Name of the weight type of the wrapped FST's arcs.
  virtual const string &WeightType() const = 0;

  // Symbol tables attached to the wrapped FST's input / output labels.
  virtual const SymbolTable *InputSymbols() const = 0;
  virtual const SymbolTable *OutputSymbols() const = 0;

  // Writes the wrapped FST to the file named `filename`.
  virtual void Write(const string &filename) const = 0;

  // Writes the wrapped FST to an already-open stream with the given options.
  virtual void Write(ostream &strm, const FstWriteOptions &opts) const = 0;

  // Property bits of the wrapped FST selected by `mask`; the implementations
  // in this header forward both arguments unchanged to Fst<Arc>::Properties.
  virtual uint64 Properties(uint64 mask, bool test) const = 0;

  // Virtual so that derived wrappers are destroyed correctly via base pointers.
  virtual ~FstClassBase() {}
};

// Base class for all arc-typed implementations that FstClass and its
// subclasses hold behind an FstClassImplBase pointer.
class FstClassImplBase : public FstClassBase {
 public:
  // Returns a newly allocated copy of this implementation. The interface
  // classes in this header take ownership of the result and delete it.
  virtual FstClassImplBase *Copy() = 0;

  // Replaces the input symbol table of the wrapped FST.
  virtual void SetInputSymbols(SymbolTable *is) = 0;

  // Replaces the output symbol table of the wrapped FST. (The parameter was
  // previously misnamed `is`; `os` matches the input/output pairing.)
  virtual void SetOutputSymbols(SymbolTable *os) = 0;

  virtual ~FstClassImplBase() { }
};


//
// CONTAINER CLASS
// Wraps an Fst<Arc>, hiding its arc type. Whether this Fst<Arc>
// pointer refers to a special kind of FST (e.g. a MutableFst) is
// known by the type of interface class that owns the pointer to this
// container.
//

template<class Arc>
class FstClassImpl : public FstClassImplBase {
 public:
  // Wraps `impl`. When `should_own` is true this object takes ownership of
  // `impl` itself; otherwise it stores impl->Copy() and the caller keeps
  // ownership of `impl`. Either way the stored pointer is deleted in the
  // destructor.
  explicit FstClassImpl(Fst<Arc> *impl,
                        bool should_own = false) :
      impl_(should_own ? impl : impl->Copy()) { }

  virtual const string &ArcType() const {
    return Arc::Type();
  }

  virtual const string &FstType() const {
    return impl_->Type();
  }

  virtual const string &WeightType() const {
    return Arc::Weight::Type();
  }

  virtual const SymbolTable *InputSymbols() const {
    return impl_->InputSymbols();
  }

  virtual const SymbolTable *OutputSymbols() const {
    return impl_->OutputSymbols();
  }

  // Warning: calling this method casts the FST to a mutable FST.
  // The cast is unchecked; it is only valid if the wrapped FST really is a
  // MutableFst<Arc>.
  virtual void SetInputSymbols(SymbolTable *is) {
    static_cast<MutableFst<Arc> *>(impl_)->SetInputSymbols(is);
  }

  // Warning: calling this method casts the FST to a mutable FST.
  // The cast is unchecked; it is only valid if the wrapped FST really is a
  // MutableFst<Arc>.
  virtual void SetOutputSymbols(SymbolTable *os) {
    static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os);
  }

  virtual void Write(const string &fname) const {
    impl_->Write(fname);
  }

  virtual void Write(ostream &ostr, const FstWriteOptions &opts) const {
    impl_->Write(ostr, opts);
  }

  virtual uint64 Properties(uint64 mask, bool test) const {
    return impl_->Properties(mask, test);
  }

  virtual ~FstClassImpl() { delete impl_; }

  // Non-owning access to the wrapped FST; ownership stays with this object.
  Fst<Arc> *GetImpl() const { return impl_; }

  Fst<Arc> *GetImpl() { return impl_; }

  // Returns a new heap-allocated wrapper holding impl_->Copy().
  virtual FstClassImpl *Copy() {
    return new FstClassImpl<Arc>(impl_);
  }

 private:
  // impl_ is an owning raw pointer. The implicitly generated copy operations
  // would share it and delete it twice, so they are declared private and
  // intentionally left undefined. Use Copy() to duplicate a wrapper.
  FstClassImpl(const FstClassImpl &);
  FstClassImpl &operator=(const FstClassImpl &);

  Fst<Arc> *impl_;
};

//
// BASE CLASS DEFINITIONS
//

class MutableFstClass;

class FstClass : public FstClassBase {
 public:
  template<class Arc>
  static FstClass *Read(istream &stream,
                        const FstReadOptions &opts) {
    if (!opts.header) {
      FSTERROR() << "FstClass::Read: options header not specified";
      return 0;
    }
    const FstHeader &hdr = *opts.header;

    if (hdr.Properties() & kMutable) {
      return ReadTypedFst<MutableFstClass, MutableFst<Arc> >(stream, opts);
    } else {
      return ReadTypedFst<FstClass, Fst<Arc> >(stream, opts);
    }
  }

  FstClass() : impl_(NULL) {
  }

  template<class Arc>
  explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) {
  }

  explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { }

  FstClass &operator=(const FstClass &other) {
    delete impl_;
    impl_ = other.impl_->Copy();
    return *this;
  }

  static FstClass *Read(const string &fname);

  static FstClass *Read(istream &istr, const string &source);

  virtual const string &ArcType() const {
    return impl_->ArcType();
  }

  virtual const string& FstType() const {
    return impl_->FstType();
  }

  virtual const SymbolTable *InputSymbols() const {
    return impl_->InputSymbols();
  }

  virtual const SymbolTable *OutputSymbols() const {
    return impl_->OutputSymbols();
  }

  virtual const string& WeightType() const {
    return impl_->WeightType();
  }

  virtual void Write(const string &fname) const {
    impl_->Write(fname);
  }

  virtual void Write(ostream &ostr, const FstWriteOptions &opts) const {
    impl_->Write(ostr, opts);
  }

  virtual uint64 Properties(uint64 mask, bool test) const {
    return impl_->Properties(mask, test);
  }

  template<class Arc>
  const Fst<Arc> *GetFst() const {
    if (Arc::Type() != ArcType()) {
      return NULL;
    } else {
      FstClassImpl<Arc> *typed_impl = static_cast<FstClassImpl<Arc> *>(impl_);
      return typed_impl->GetImpl();
    }
  }

  virtual ~FstClass() { delete impl_; }

  // These methods are required by IO registration
  template<class Arc>
  static FstClassImplBase *Convert(const FstClass &other) {
    LOG(ERROR) << "Doesn't make sense to convert any class to type FstClass.";
    return 0;
  }

  template<class Arc>
  static FstClassImplBase *Create() {
    LOG(ERROR) << "Doesn't make sense to create an FstClass with a "
               << "particular arc type.";
    return 0;
  }


 protected:
  explicit FstClass(FstClassImplBase *impl) : impl_(impl) { }

  // Generic template method for reading an arc-templated FST of type
  // UnderlyingT, and returning it wrapped as FstClassT, with appropriate
  // error checking. Called from arc-templated Read() static methods.
  template<class FstClassT, class UnderlyingT>
  static FstClassT* ReadTypedFst(istream &stream,
                                     const FstReadOptions &opts) {
    UnderlyingT *u = UnderlyingT::Read(stream, opts);
    if (!u) {
      return 0;
    } else {
      FstClassT *r = new FstClassT(u);
      delete u;
      return r;
    }
  }

  FstClassImplBase *GetImpl() const { return impl_; }

  FstClassImplBase *GetImpl() { return impl_; }

//  friend ostream &operator<<(ostream&, const FstClass&);

 private:
  FstClassImplBase *impl_;
};

//
// Specific types of FstClass with special properties
//

class MutableFstClass : public FstClass {
 public:
  template<class Arc>
  explicit MutableFstClass(MutableFst<Arc> *fst) :
      FstClass(fst) { }

  template<class Arc>
  MutableFst<Arc> *GetMutableFst() {
    Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>());
    MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst);

    return mfst;
  }

  template<class Arc>
  static MutableFstClass *Read(istream &stream,
                               const FstReadOptions &opts) {
    MutableFst<Arc> *mfst = MutableFst<Arc>::Read(stream, opts);
    if (!mfst) {
      return 0;
    } else {
      MutableFstClass *retval = new MutableFstClass(mfst);
      delete mfst;
      return retval;
    }
  }

  virtual void Write(const string &fname) const {
    GetImpl()->Write(fname);
  }

  virtual void Write(ostream &ostr, const FstWriteOptions &opts) const {
    GetImpl()->Write(ostr, opts);
  }

  static MutableFstClass *Read(const string &fname, bool convert = false);

  virtual void SetInputSymbols(SymbolTable *is) {
    GetImpl()->SetInputSymbols(is);
  }

  virtual void SetOutputSymbols(SymbolTable *os) {
    GetImpl()->SetOutputSymbols(os);
  }

  // These methods are required by IO registration
  template<class Arc>
  static FstClassImplBase *Convert(const FstClass &other) {
    LOG(ERROR) << "Doesn't make sense to convert any class to type "
               << "MutableFstClass.";
    return 0;
  }

  template<class Arc>
  static FstClassImplBase *Create() {
    LOG(ERROR) << "Doesn't make sense to create a MutableFstClass with a "
               << "particular arc type.";
    return 0;
  }

 protected:
  explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) { }
};


class VectorFstClass : public MutableFstClass {
 public:
  explicit VectorFstClass(const FstClass &other);
  explicit VectorFstClass(const string &arc_type);

  template<class Arc>
  explicit VectorFstClass(VectorFst<Arc> *fst) :
      MutableFstClass(fst) { }

  template<class Arc>
  static VectorFstClass *Read(istream &stream,
                              const FstReadOptions &opts) {
    VectorFst<Arc> *vfst = VectorFst<Arc>::Read(stream, opts);
    if (!vfst) {
      return 0;
    } else {
      VectorFstClass *retval = new VectorFstClass(vfst);
      delete vfst;
      return retval;
    }
  }

  static VectorFstClass *Read(const string &fname);

  // Converter / creator for known arc types
  template<class Arc>
  static FstClassImplBase *Convert(const FstClass &other) {
    return new FstClassImpl<Arc>(new VectorFst<Arc>(
        *other.GetFst<Arc>()), true);
  }

  template<class Arc>
  static FstClassImplBase *Create() {
    return new FstClassImpl<Arc>(new VectorFst<Arc>(), true);
  }
};

}  // namespace script
}  // namespace fst
#endif  // FST_SCRIPT_FST_CLASS_H_