aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/state-map.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/state-map.h')
-rw-r--r--src/include/fst/state-map.h601
1 files changed, 601 insertions, 0 deletions
diff --git a/src/include/fst/state-map.h b/src/include/fst/state-map.h
new file mode 100644
index 0000000..ace4a3c
--- /dev/null
+++ b/src/include/fst/state-map.h
@@ -0,0 +1,601 @@
+// state-map.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Class to map over/transform states e.g., sort transitions
+// Consider using when operation does not change the number of states.
+
+#ifndef FST_LIB_STATE_MAP_H__
+#define FST_LIB_STATE_MAP_H__
+
+#include <algorithm>
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <string>
+#include <utility>
+using std::pair; using std::make_pair;
+
+#include <fst/cache.h>
+#include <fst/arc-map.h>
+#include <fst/mutable-fst.h>
+
+
+namespace fst {
+
+// StateMapper Interface - class determines how states are mapped.
+// Useful for implementing operations that do not change the number of states.
+//
+// class StateMapper {
+// public:
+// typedef A FromArc;
+// typedef B ToArc;
+//
+// // Typical constructor
+// StateMapper(const Fst<A> &fst);
+// // Required copy constructor that allows updating Fst argument;
+// // pass only if relevant and changed.
+// StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0);
+//
+// // Specifies initial state of result
+// B::StateId Start() const;
+// // Specifies state's final weight in result
+// B::Weight Final(B::StateId s) const;
+//
+// // These methods iterate through a state's arcs in result
+// // Specifies state to iterate over
+// void SetState(B::StateId s);
+// // End of arcs?
+// bool Done() const;
+// // Current arc
+// const B &Value() const;
+// // Advance to next arc (when !Done)
+// void Next();
+//
+// // Specifies input symbol table action the mapper requires (see above).
+// MapSymbolsAction InputSymbolsAction() const;
+// // Specifies output symbol table action the mapper requires (see above).
+// MapSymbolsAction OutputSymbolsAction() const;
+// // This specifies the known properties of an Fst mapped by this
+// // mapper. It takes as argument the input Fst's known properties.
+// uint64 Properties(uint64 props) const;
+// };
+//
+// We include various state map versions below. One dimension of
+// variation is whether the mapping mutates its input, writes to a
+// new result Fst, or is an on-the-fly Fst. Another dimension is how
+// we pass the mapper. We allow passing the mapper by pointer
+// for cases that we need to change the state of the user's mapper.
+// We also include map versions that pass the mapper
+// by value or const reference when this suffices.
+
+// In-place state map: rewrites the arcs and final weight of every state
+// of *fst with whatever the mapper (function object C, passed by pointer)
+// produces for that state. Symbol tables are cleared when the mapper asks
+// for MAP_CLEAR_SYMBOLS. An Fst without a start state is left as is
+// (apart from any symbol-table clearing).
+template<class A, class C>
+void StateMap(MutableFst<A> *fst, C* mapper) {
+  typedef typename A::StateId StateId;
+  typedef typename A::Weight Weight;
+
+  // Drop the symbol tables the mapper wants cleared.
+  if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
+    fst->SetInputSymbols(0);
+  }
+  if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
+    fst->SetOutputSymbols(0);
+  }
+
+  if (fst->Start() != kNoStateId) {
+    // Known properties of the input, read before anything is modified.
+    const uint64 input_props = fst->Properties(kFstProperties, false);
+
+    fst->SetStart(mapper->Start());
+
+    StateId state = 0;
+    while (state < fst->NumStates()) {
+      // The mapper is positioned on the state before the state's old
+      // arcs are removed; the mapped arcs are then added back.
+      mapper->SetState(state);
+      fst->DeleteArcs(state);
+      while (!mapper->Done()) {
+        fst->AddArc(state, mapper->Value());
+        mapper->Next();
+      }
+      fst->SetFinal(state, mapper->Final(state));
+      ++state;
+    }
+
+    fst->SetProperties(mapper->Properties(input_props), kFstProperties);
+  }
+}
+
+// In-place state map taking the mapper function object C by value.
+// Forwards to the pointer overload with the address of the local copy,
+// so any state the mapper accumulates is not visible to the caller.
+template<class A, class C>
+void StateMap(MutableFst<A> *fst, C mapper) {
+  C *mapper_ptr = &mapper;
+  StateMap(fst, mapper_ptr);
+}
+
+
+// Maps an Fst over arc type A into a MutableFst over arc type B using
+// the mapper function object C, passed by pointer. The output is first
+// emptied with DeleteStates(), then receives one state per input state,
+// with arcs and final weights supplied by the mapper.
+template<class A, class B, class C>
+void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) {
+  typedef typename A::StateId StateId;
+  typedef typename A::Weight Weight;
+
+  ofst->DeleteStates();
+
+  // Input symbol table: copy from the input, clear, or leave untouched.
+  if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
+    ofst->SetInputSymbols(ifst.InputSymbols());
+  } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
+    ofst->SetInputSymbols(0);
+  }
+
+  // Output symbol table: same three-way choice.
+  if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
+    ofst->SetOutputSymbols(ifst.OutputSymbols());
+  } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
+    ofst->SetOutputSymbols(0);
+  }
+
+  const uint64 input_props = ifst.Properties(kCopyProperties, false);
+
+  // Input without a start state: the output stays empty and only an
+  // error flag on the input is carried over.
+  if (ifst.Start() == kNoStateId) {
+    if (input_props & kError) {
+      ofst->SetProperties(kError, kError);
+    }
+    return;
+  }
+
+  // Add all states (reserving space up front when the input is expanded).
+  if (ifst.Properties(kExpanded, false)) {
+    ofst->ReserveStates(CountStates(ifst));
+  }
+  for (StateIterator< Fst<A> > add_iter(ifst); !add_iter.Done();
+       add_iter.Next()) {
+    ofst->AddState();
+  }
+
+  ofst->SetStart(mapper->Start());
+
+  // Fill in the mapped arcs and final weight of every state.
+  for (StateIterator< Fst<A> > map_iter(ifst); !map_iter.Done();
+       map_iter.Next()) {
+    const StateId state = map_iter.Value();
+    mapper->SetState(state);
+    while (!mapper->Done()) {
+      ofst->AddArc(state, mapper->Value());
+      mapper->Next();
+    }
+    ofst->SetFinal(state, mapper->Final(state));
+  }
+
+  // Combine what the mapper reports with what the output already knows.
+  const uint64 output_props = ofst->Properties(kFstProperties, false);
+  ofst->SetProperties(mapper->Properties(input_props) | output_props,
+                      kFstProperties);
+}
+
+// Maps an Fst over arc type A into a MutableFst over arc type B using
+// the mapper function object C, passed by value. Forwards to the
+// pointer overload with the address of the local copy of the mapper.
+template<class A, class B, class C>
+void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
+  C *mapper_ptr = &mapper;
+  StateMap(ifst, ofst, mapper_ptr);
+}
+
+typedef CacheOptions StateMapFstOptions;
+
+template <class A, class B, class C> class StateMapFst;
+
+// Implementation of delayed StateMapFst.
+//
+// Keeps a copy of the input Fst together with a mapper. The start state,
+// final weights and arcs are requested from the mapper on first use and
+// stored through the CacheImpl<B> base. The mapper is owned by the impl
+// unless the impl was built from a caller-supplied mapper pointer.
+template <class A, class B, class C>
+class StateMapFstImpl : public CacheImpl<B> {
+ public:
+  using FstImpl<B>::SetType;
+  using FstImpl<B>::SetProperties;
+  using FstImpl<B>::SetInputSymbols;
+  using FstImpl<B>::SetOutputSymbols;
+
+  using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates;
+
+  using CacheImpl<B>::PushArc;
+  using CacheImpl<B>::HasArcs;
+  using CacheImpl<B>::HasFinal;
+  using CacheImpl<B>::HasStart;
+  using CacheImpl<B>::SetArcs;
+  using CacheImpl<B>::SetFinal;
+  using CacheImpl<B>::SetStart;
+
+  friend class StateIterator< StateMapFst<A, B, C> >;
+
+  typedef B Arc;
+  typedef typename B::Weight Weight;
+  typedef typename B::StateId StateId;
+
+  // Copies the input Fst and builds an owned mapper from `mapper`,
+  // re-targeted at that copy.
+  StateMapFstImpl(const Fst<A> &fst, const C &mapper,
+                  const StateMapFstOptions& opts)
+      : CacheImpl<B>(opts),
+        fst_(fst.Copy()),
+        mapper_(new C(mapper, fst_)),
+        own_mapper_(true) {
+    Init();
+  }
+
+  // Copies the input Fst but uses the caller's mapper object directly;
+  // the mapper is not owned and is not deleted by this impl.
+  StateMapFstImpl(const Fst<A> &fst, C *mapper,
+                  const StateMapFstOptions& opts)
+      : CacheImpl<B>(opts),
+        fst_(fst.Copy()),
+        mapper_(mapper),
+        own_mapper_(false) {
+    Init();
+  }
+
+  // Copy constructor: takes a Copy(true) of the other impl's Fst and
+  // always builds a fresh, owned mapper targeted at that copy.
+  StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl)
+      : CacheImpl<B>(impl),
+        fst_(impl.fst_->Copy(true)),
+        mapper_(new C(*impl.mapper_, fst_)),
+        own_mapper_(true) {
+    Init();
+  }
+
+  ~StateMapFstImpl() {
+    // Destroy an owned mapper before the Fst copy: an owned mapper was
+    // constructed with a pointer to fst_ and may still refer to it (for
+    // instance through an arc iterator) when its destructor runs.
+    if (own_mapper_) delete mapper_;
+    delete fst_;
+  }
+
+  // Start state of the result; asked of the mapper once, then cached.
+  StateId Start() {
+    if (!HasStart())
+      SetStart(mapper_->Start());
+    return CacheImpl<B>::Start();
+  }
+
+  // Final weight of state s; asked of the mapper once, then cached.
+  Weight Final(StateId s) {
+    if (!HasFinal(s))
+      SetFinal(s, mapper_->Final(s));
+    return CacheImpl<B>::Final(s);
+  }
+
+  // The three counting accessors below expand state s on first use.
+  size_t NumArcs(StateId s) {
+    if (!HasArcs(s))
+      Expand(s);
+    return CacheImpl<B>::NumArcs(s);
+  }
+
+  size_t NumInputEpsilons(StateId s) {
+    if (!HasArcs(s))
+      Expand(s);
+    return CacheImpl<B>::NumInputEpsilons(s);
+  }
+
+  size_t NumOutputEpsilons(StateId s) {
+    if (!HasArcs(s))
+      Expand(s);
+    return CacheImpl<B>::NumOutputEpsilons(s);
+  }
+
+  // State iteration is delegated to the input Fst.
+  // NOTE(review): the data argument is typed on the input arc A while
+  // this impl's arc type is B, so this appears usable as the Fst<B>
+  // state-iterator hook only when A and B are the same type - confirm.
+  void InitStateIterator(StateIteratorData<A> *data) const {
+    fst_->InitStateIterator(data);
+  }
+
+  // Arc iteration over the cached arcs of s, expanding s if needed.
+  void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
+    if (!HasArcs(s))
+      Expand(s);
+    CacheImpl<B>::InitArcIterator(s, data);
+  }
+
+  uint64 Properties() const { return Properties(kFstProperties); }
+
+  // Set error if found; return FST impl properties.
+  // The error bit is picked up from either the input Fst or the mapper,
+  // and only when the caller's mask asks about kError.
+  uint64 Properties(uint64 mask) const {
+    if ((mask & kError) && (fst_->Properties(kError, false) ||
+                            (mapper_->Properties(0) & kError)))
+      SetProperties(kError, kError);
+    return FstImpl<Arc>::Properties(mask);
+  }
+
+  // Pulls every arc the mapper produces for state s into the cache and
+  // marks the state's arcs as set.
+  void Expand(StateId s) {
+    // Add exiting arcs.
+    for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next())
+      PushArc(s, mapper_->Value());
+    SetArcs(s);
+  }
+
+ private:
+  // Shared constructor tail: sets the Fst type string, applies the
+  // mapper's symbol-table actions and seeds the properties from the
+  // input's copy properties as transformed by the mapper.
+  void Init() {
+    SetType("statemap");
+
+    if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS)
+      SetInputSymbols(fst_->InputSymbols());
+    else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
+      SetInputSymbols(0);
+
+    if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
+      SetOutputSymbols(fst_->OutputSymbols());
+    else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
+      SetOutputSymbols(0);
+
+    uint64 props = fst_->Properties(kCopyProperties, false);
+    SetProperties(mapper_->Properties(props));
+  }
+
+  const Fst<A> *fst_;   // owned copy of the input Fst
+  C* mapper_;           // owned only when own_mapper_ is true
+  bool own_mapper_;
+
+  void operator=(const StateMapFstImpl<A, B, C> &);  // disallow
+};
+
+
+// Delayed Fst that maps an arc type A to an arc type B using the mapper
+// function object C. All work is forwarded to a StateMapFstImpl.
+template <class A, class B, class C>
+class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > {
+ public:
+  friend class ArcIterator< StateMapFst<A, B, C> >;
+
+  typedef B Arc;
+  typedef typename B::Weight Weight;
+  typedef typename B::StateId StateId;
+  typedef CacheState<B> State;
+  typedef StateMapFstImpl<A, B, C> Impl;
+
+  // Mapper passed by const reference, default options.
+  StateMapFst(const Fst<A> &fst, const C &mapper)
+      : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
+
+  // Mapper passed by pointer, default options.
+  StateMapFst(const Fst<A> &fst, C* mapper)
+      : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
+
+  // Mapper passed by const reference, explicit options.
+  StateMapFst(const Fst<A> &fst, const C &mapper,
+              const StateMapFstOptions& opts)
+      : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
+
+  // Mapper passed by pointer, explicit options.
+  StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts)
+      : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
+
+  // See Fst<>::Copy() for doc.
+  StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false)
+      : ImplToFst<Impl>(fst, safe) {}
+
+  // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc.
+  virtual StateMapFst<A, B, C> *Copy(bool safe = false) const {
+    StateMapFst<A, B, C> *duplicate = new StateMapFst<A, B, C>(*this, safe);
+    return duplicate;
+  }
+
+  // Both iterator hooks hand the request to the implementation.
+  virtual void InitStateIterator(StateIteratorData<A> *data) const {
+    Impl *impl = GetImpl();
+    impl->InitStateIterator(data);
+  }
+
+  virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
+    Impl *impl = GetImpl();
+    impl->InitArcIterator(s, data);
+  }
+
+ private:
+  // Makes visible to friends.
+  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
+
+  void operator=(const StateMapFst<A, B, C> &fst);  // disallow
+};
+
+
+// ArcIterator specialization for StateMapFst: a cache arc iterator that
+// makes sure the state has been expanded when the iterator is built.
+template <class A, class B, class C>
+class ArcIterator< StateMapFst<A, B, C> >
+    : public CacheArcIterator< StateMapFst<A, B, C> > {
+ public:
+  typedef typename A::StateId StateId;
+
+  ArcIterator(const StateMapFst<A, B, C> &fst, StateId s)
+      : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) {
+    typename StateMapFst<A, B, C>::Impl *impl = fst.GetImpl();
+    if (impl->HasArcs(s)) return;
+    impl->Expand(s);
+  }
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
+};
+
+//
+// Utility Mappers
+//
+
+// Mapper that returns its input: start state, final weights and arcs
+// are read straight from the referenced Fst. The Fst is held by
+// reference and is not copied.
+template <class A>
+class IdentityStateMapper {
+ public:
+  typedef A FromArc;
+  typedef A ToArc;
+
+  typedef typename A::StateId StateId;
+  typedef typename A::Weight Weight;
+
+  explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {}
+
+  // Allows updating Fst argument; pass only if changed.
+  // The copy starts without an arc iterator regardless of the source.
+  IdentityStateMapper(const IdentityStateMapper<A> &mapper,
+                      const Fst<A> *fst = 0)
+      : fst_(fst != 0 ? *fst : mapper.fst_), aiter_(0) {}
+
+  ~IdentityStateMapper() { delete aiter_; }
+
+  StateId Start() const { return fst_.Start(); }
+
+  Weight Final(StateId s) const { return fst_.Final(s); }
+
+  // Replaces the current arc iterator with one positioned on state s.
+  void SetState(StateId s) {
+    // Deleting a null pointer is a no-op, so no guard is required.
+    delete aiter_;
+    aiter_ = new ArcIterator< Fst<A> >(fst_, s);
+  }
+
+  // Arc iteration forwards to the iterator created by SetState().
+  bool Done() const { return aiter_->Done(); }
+
+  const A &Value() const { return aiter_->Value(); }
+
+  void Next() { aiter_->Next(); }
+
+  MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
+
+  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
+
+  // Every input property is reported unchanged.
+  uint64 Properties(uint64 props) const { return props; }
+
+ private:
+  const Fst<A> &fst_;
+  ArcIterator< Fst<A> > *aiter_;  // owned; null until SetState() is called
+};
+
+// Mapper that merges the arcs leaving a state which share input label,
+// output label and destination state into a single arc whose weight is
+// the Plus of the merged arcs' weights. Start state and final weights
+// are read from the referenced Fst unchanged. The Fst is held by
+// reference and is not copied.
+template <class A>
+class ArcSumMapper {
+ public:
+  typedef A FromArc;
+  typedef A ToArc;
+
+  typedef typename A::StateId StateId;
+  typedef typename A::Weight Weight;
+
+  explicit ArcSumMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
+
+  // Allows updating Fst argument; pass only if changed.
+  ArcSumMapper(const ArcSumMapper<A> &mapper,
+               const Fst<A> *fst = 0)
+      : fst_(fst ? *fst : mapper.fst_), i_(0) {}
+
+  StateId Start() const { return fst_.Start(); }
+  Weight Final(StateId s) const { return fst_.Final(s); }
+
+  // Loads the arcs of state s, sorts them and sums duplicates; the
+  // result is then served through Done()/Value()/Next().
+  void SetState(StateId s) {
+    i_ = 0;
+    arcs_.clear();
+    arcs_.reserve(fst_.NumArcs(s));
+    for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
+      arcs_.push_back(aiter.Value());
+
+    // First sorts the exiting arcs by input label, output label
+    // and destination state and then sums weights of arcs with
+    // the same input label, output label, and destination state.
+    sort(arcs_.begin(), arcs_.end(), comp_);
+    // Compact in place: narcs is the number of merged arcs kept so far,
+    // and arcs_[narcs - 1] is the arc currently being accumulated into.
+    size_t narcs = 0;
+    for (size_t i = 0; i < arcs_.size(); ++i) {
+      if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) {
+        arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight,
+                                       arcs_[i].weight);
+      } else {
+        arcs_[narcs++] = arcs_[i];
+      }
+    }
+    arcs_.resize(narcs);
+  }
+
+  bool Done() const { return i_ >= arcs_.size(); }
+  const A &Value() const { return arcs_[i_]; }
+  void Next() { ++i_; }
+
+  MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
+  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
+
+  // Keeps only the input properties present in all three masks.
+  uint64 Properties(uint64 props) const {
+    return props & kArcSortProperties &
+        kDeleteArcsProperties & kWeightInvariantProperties;
+  }
+
+ private:
+  // Strict ordering on (ilabel, olabel, nextstate); weights are ignored.
+  struct Compare {
+    bool operator()(const A& x, const A& y) const {
+      if (x.ilabel < y.ilabel) return true;
+      if (x.ilabel > y.ilabel) return false;
+      if (x.olabel < y.olabel) return true;
+      if (x.olabel > y.olabel) return false;
+      if (x.nextstate < y.nextstate) return true;
+      if (x.nextstate > y.nextstate) return false;
+      return false;
+    }
+  };
+
+  // Equality on (ilabel, olabel, nextstate); weights are ignored.
+  struct Equal {
+    bool operator()(const A& x, const A& y) const {
+      return (x.ilabel == y.ilabel &&
+              x.olabel == y.olabel &&
+              x.nextstate == y.nextstate);
+    }
+  };
+
+  const Fst<A> &fst_;
+  Compare comp_;
+  Equal equal_;
+  vector<A> arcs_;
+  size_t i_;  // current arc position (unsigned, to match arcs_.size())
+
+  void operator=(const ArcSumMapper<A> &);  // disallow
+};
+
+// Mapper that removes duplicate arcs leaving a state, where two arcs are
+// duplicates when input label, output label, destination state and
+// weight all compare equal. Start state and final weights are read from
+// the referenced Fst unchanged. The Fst is held by reference and is not
+// copied.
+template <class A>
+class ArcUniqueMapper {
+ public:
+  typedef A FromArc;
+  typedef A ToArc;
+
+  typedef typename A::StateId StateId;
+  typedef typename A::Weight Weight;
+
+  explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
+
+  // Allows updating Fst argument; pass only if changed.
+  // Takes an ArcUniqueMapper (not an ArcSumMapper) so that the mapper
+  // can be copied in the C(const C &, const Fst<A> *) form that
+  // StateMapFstImpl uses.
+  ArcUniqueMapper(const ArcUniqueMapper<A> &mapper,
+                  const Fst<A> *fst = 0)
+      : fst_(fst ? *fst : mapper.fst_), i_(0) {}
+
+  StateId Start() const { return fst_.Start(); }
+  Weight Final(StateId s) const { return fst_.Final(s); }
+
+  // Loads the arcs of state s, sorts them and drops adjacent duplicates;
+  // the result is then served through Done()/Value()/Next().
+  void SetState(StateId s) {
+    i_ = 0;
+    arcs_.clear();
+    arcs_.reserve(fst_.NumArcs(s));
+    for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
+      arcs_.push_back(aiter.Value());
+
+    // First sorts the exiting arcs by input label, output label
+    // and destination state and then uniques identical arcs.
+    // NOTE: the sort key ignores weights while equality includes them,
+    // so identical arcs separated by an arc with the same labels and
+    // destination but a different weight are not adjacent and are kept.
+    sort(arcs_.begin(), arcs_.end(), comp_);
+    typename vector<A>::iterator unique_end =
+        unique(arcs_.begin(), arcs_.end(), equal_);
+    arcs_.erase(unique_end, arcs_.end());
+  }
+
+  bool Done() const { return i_ >= arcs_.size(); }
+  const A &Value() const { return arcs_[i_]; }
+  void Next() { ++i_; }
+
+  MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
+  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
+
+  // Keeps only the input properties present in both masks.
+  uint64 Properties(uint64 props) const {
+    return props & kArcSortProperties & kDeleteArcsProperties;
+  }
+
+ private:
+  // Strict ordering on (ilabel, olabel, nextstate); weights are ignored.
+  struct Compare {
+    bool operator()(const A& x, const A& y) const {
+      if (x.ilabel < y.ilabel) return true;
+      if (x.ilabel > y.ilabel) return false;
+      if (x.olabel < y.olabel) return true;
+      if (x.olabel > y.olabel) return false;
+      if (x.nextstate < y.nextstate) return true;
+      if (x.nextstate > y.nextstate) return false;
+      return false;
+    }
+  };
+
+  // Full arc equality, including the weight.
+  struct Equal {
+    bool operator()(const A& x, const A& y) const {
+      return (x.ilabel == y.ilabel &&
+              x.olabel == y.olabel &&
+              x.nextstate == y.nextstate &&
+              x.weight == y.weight);
+    }
+  };
+
+  const Fst<A> &fst_;
+  Compare comp_;
+  Equal equal_;
+  vector<A> arcs_;
+  size_t i_;  // current arc position (unsigned, to match arcs_.size())
+
+  void operator=(const ArcUniqueMapper<A> &);  // disallow
+};
+
+
+} // namespace fst
+
+#endif // FST_LIB_STATE_MAP_H__