aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/connect.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/connect.h')
-rw-r--r--src/include/fst/connect.h319
1 files changed, 319 insertions, 0 deletions
diff --git a/src/include/fst/connect.h b/src/include/fst/connect.h
new file mode 100644
index 0000000..427808c
--- /dev/null
+++ b/src/include/fst/connect.h
@@ -0,0 +1,319 @@
+// connect.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Classes and functions to remove unsuccessful paths from an Fst.
+
+#ifndef FST_LIB_CONNECT_H__
+#define FST_LIB_CONNECT_H__
+
+#include <vector>
+using std::vector;
+
+#include <fst/dfs-visit.h>
+#include <fst/union-find.h>
+#include <fst/mutable-fst.h>
+
+
+namespace fst {
+
+// Finds and returns connected components. Use with Visit().
+//
+// Every arc reported to the visitor merges the equivalence classes of its
+// source and destination states, regardless of the arc's colour, so the
+// resulting classes do not depend on arc direction.
+template <class A>
+class CcVisitor {
+ public:
+  typedef A Arc;
+  typedef typename Arc::Weight Weight;
+  typedef typename A::StateId StateId;
+
+  // cc[i]: connected component number for state i.
+  // The vector is (re)filled by FinishVisit(). The visitor allocates its
+  // own union-find structure and releases it in the destructor, even when
+  // 'cc' happens to be a null pointer.
+  CcVisitor(vector<StateId> *cc)
+      : comps_(new UnionFind<StateId>(0, kNoStateId)),
+        cc_(cc),
+        nstates_(0),
+        own_comps_(true) { }
+
+  // comps: connected components equiv classes.
+  // The caller keeps ownership of 'comps'; the visitor only updates it.
+  CcVisitor(UnionFind<StateId> *comps)
+      : comps_(comps),
+        cc_(0),
+        nstates_(0),
+        own_comps_(false) { }
+
+  // NOTE(review): the class remains implicitly copyable, as before; copying
+  // a visitor that owns its union-find would delete it twice.
+  ~CcVisitor() {
+    if (own_comps_)
+      delete comps_;
+  }
+
+  void InitVisit(const Fst<A> &fst) { }
+
+  // Counts the state and gives it a singleton class unless it already
+  // belongs to one.
+  bool InitState(StateId s, StateId root) {
+    ++nstates_;
+    if (comps_->FindSet(s) == kNoStateId)
+      comps_->MakeSet(s);
+    return true;
+  }
+
+  // Creates a class for the arc's destination, then merges it with the
+  // class of the source state.
+  bool WhiteArc(StateId s, const A &arc) {
+    comps_->MakeSet(arc.nextstate);
+    comps_->Union(s, arc.nextstate);
+    return true;
+  }
+
+  // Merges the classes of the arc's two end points.
+  bool GreyArc(StateId s, const A &arc) {
+    comps_->Union(s, arc.nextstate);
+    return true;
+  }
+
+  // Merges the classes of the arc's two end points.
+  bool BlackArc(StateId s, const A &arc) {
+    comps_->Union(s, arc.nextstate);
+    return true;
+  }
+
+  void FinishState(StateId s) { }
+
+  // Fills the caller's component vector, if one was supplied.
+  void FinishVisit() {
+    if (cc_)
+      GetCcVector(cc_);
+  }
+
+  // cc[i]: connected component number for state i.
+  // Returns number of components.
+  // Components are numbered from 0 in order of the lowest state ID they
+  // contain. A state with no usable representative keeps kNoStateId.
+  int GetCcVector(vector<StateId> *cc) {
+    cc->clear();
+    cc->resize(nstates_, kNoStateId);
+    StateId ncomp = 0;
+    for (StateId i = 0; i < nstates_; ++i) {
+      StateId rep = comps_->FindSet(i);
+      // State i was never put in a set, or its representative has no slot
+      // in 'cc': indexing with it would write out of bounds.
+      if (rep == kNoStateId || rep >= nstates_)
+        continue;
+      // The representative's own slot doubles as the place where the
+      // component number of the whole class is first recorded.
+      StateId &comp = (*cc)[rep];
+      if (comp == kNoStateId) {
+        comp = ncomp;
+        ++ncomp;
+      }
+      (*cc)[i] = comp;
+    }
+    return ncomp;
+  }
+
+ private:
+  UnionFind<StateId> *comps_;  // Components
+  vector<StateId> *cc_;        // State's cc number
+  StateId nstates_;            // State count
+  bool own_comps_;             // Was comps_ allocated by this object?
+};
+
+
+// Finds and returns strongly-connected components, accessible and
+// coaccessible states and related properties. Uses Tarjan's single
+// DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer
+// Algorithms", 189pp). Use with DfsVisit();
+//
+// All working storage is allocated in InitVisit() and released in
+// FinishVisit(), so the same visitor object may be used for several
+// traversals, one after the other.
+template <class A>
+class SccVisitor {
+ public:
+  typedef A Arc;
+  typedef typename A::Weight Weight;
+  typedef typename A::StateId StateId;
+
+  // scc[i]: strongly-connected component number for state i.
+  // SCC numbers will be in topological order for acyclic input.
+  // access[i]: accessibility of state i.
+  // coaccess[i]: coaccessibility of state i.
+  // Any of above can be NULL.
+  // props: related property bits (cyclicity, initial cyclicity,
+  // accessibility, coaccessibility) set/cleared (o.w. unchanged).
+  // 'props' is dereferenced unconditionally and so must not be NULL.
+  SccVisitor(vector<StateId> *scc, vector<bool> *access,
+             vector<bool> *coaccess, uint64 *props)
+      : scc_(scc), access_(access), coaccess_(coaccess), props_(props),
+        fst_(0), start_(kNoStateId), nstates_(0), nscc_(0),
+        coaccess_internal_(false), dfnumber_(0), lowlink_(0), onstack_(0),
+        scc_stack_(0) {}
+
+  // Computes the property bits only; no per-state output is returned.
+  SccVisitor(uint64 *props)
+      : scc_(0), access_(0), coaccess_(0), props_(props),
+        fst_(0), start_(kNoStateId), nstates_(0), nscc_(0),
+        coaccess_internal_(false), dfnumber_(0), lowlink_(0), onstack_(0),
+        scc_stack_(0) {}
+
+  void InitVisit(const Fst<A> &fst);
+
+  bool InitState(StateId s, StateId root);
+
+  bool TreeArc(StateId s, const A &arc) { return true; }
+
+  // Lowers s's lowlink to the destination's discovery number if that is
+  // smaller, propagates coaccessibility from the destination to s, and
+  // records that the FST is cyclic (initially cyclic when the arc enters
+  // the start state).
+  bool BackArc(StateId s, const A &arc) {
+    StateId t = arc.nextstate;
+    if ((*dfnumber_)[t] < (*lowlink_)[s])
+      (*lowlink_)[s] = (*dfnumber_)[t];
+    if ((*coaccess_)[t])
+      (*coaccess_)[s] = true;
+    *props_ |= kCyclic;
+    *props_ &= ~kAcyclic;
+    if (arc.nextstate == start_) {
+      *props_ |= kInitialCyclic;
+      *props_ &= ~kInitialAcyclic;
+    }
+    return true;
+  }
+
+  // Only a cross arc whose destination is still on the SCC stack can
+  // lower s's lowlink. Coaccessibility is propagated in either case.
+  bool ForwardOrCrossArc(StateId s, const A &arc) {
+    StateId t = arc.nextstate;
+    if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ &&
+        (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s])
+      (*lowlink_)[s] = (*dfnumber_)[t];
+    if ((*coaccess_)[t])
+      (*coaccess_)[s] = true;
+    return true;
+  }
+
+  void FinishState(StateId s, StateId p, const A *);
+
+  // Reverses the SCC numbering and releases the working storage.
+  void FinishVisit() {
+    // Numbers SCC's in topological order when acyclic.
+    if (scc_)
+      for (size_t i = 0; i < scc_->size(); ++i)
+        (*scc_)[i] = nscc_ - 1 - (*scc_)[i];
+    if (coaccess_internal_) {
+      delete coaccess_;
+      // Forget the freed vector so that the next InitVisit() allocates a
+      // fresh one instead of calling clear() on released memory.
+      coaccess_ = 0;
+      coaccess_internal_ = false;
+    }
+    // Null the pointers after deletion so that a repeated FinishVisit()
+    // is harmless.
+    delete dfnumber_;
+    dfnumber_ = 0;
+    delete lowlink_;
+    lowlink_ = 0;
+    delete onstack_;
+    onstack_ = 0;
+    delete scc_stack_;
+    scc_stack_ = 0;
+  }
+
+ private:
+  vector<StateId> *scc_;        // State's scc number
+  vector<bool> *access_;        // State's accessibility
+  vector<bool> *coaccess_;      // State's coaccessibility
+  uint64 *props_;               // Property bits updated during the visit
+  const Fst<A> *fst_;           // FST being visited
+  StateId start_;               // Its start state
+  StateId nstates_;             // State count
+  StateId nscc_;                // SCC count
+  bool coaccess_internal_;      // Was coaccess_ allocated by InitVisit()?
+  vector<StateId> *dfnumber_;   // state discovery times
+  vector<StateId> *lowlink_;    // lowlink[s] == dfnumber[s] => SCC root
+  vector<bool> *onstack_;       // is a state on the SCC stack
+  vector<StateId> *scc_stack_;  // SCC stack (w/ random access)
+};
+
+// Prepares the visitor for a traversal of 'fst': resets the counters,
+// empties the caller's output vectors, allocates the working storage and
+// starts from the optimistic assumption that the FST is acyclic,
+// accessible and coaccessible.
+template <class A> inline
+void SccVisitor<A>::InitVisit(const Fst<A> &fst) {
+  // Per-traversal bookkeeping.
+  fst_ = &fst;
+  start_ = fst.Start();
+  nstates_ = 0;
+  nscc_ = 0;
+
+  // Caller-supplied outputs are rebuilt from scratch.
+  if (scc_)
+    scc_->clear();
+  if (access_)
+    access_->clear();
+
+  // Coaccessibility is needed by the visit itself, so a private vector
+  // stands in when the caller did not ask for one.
+  coaccess_internal_ = (coaccess_ == 0);
+  if (coaccess_internal_)
+    coaccess_ = new vector<bool>;
+  else
+    coaccess_->clear();
+
+  // The "positive" bits are set and their negations cleared; the visit
+  // flips them when it finds evidence to the contrary.
+  *props_ = (*props_ |
+             (kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible)) &
+            ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible);
+
+  // Working storage, released again by FinishVisit().
+  scc_stack_ = new vector<StateId>;
+  onstack_ = new vector<bool>;
+  lowlink_ = new vector<StateId>;
+  dfnumber_ = new vector<StateId>;
+}
+
+// Called when state 's' is discovered in the DFS tree rooted at 'root'.
+// Assigns the state its discovery number, pushes it on the SCC stack and
+// records whether it is accessible, i.e. whether its tree is rooted at
+// the start state. Always returns true.
+template <class A> inline
+bool SccVisitor<A>::InitState(StateId s, StateId root) {
+  // Grow every per-state table until index s is valid; new entries get
+  // their "not yet known" defaults.
+  while (dfnumber_->size() <= s) {
+    dfnumber_->push_back(-1);
+    lowlink_->push_back(-1);
+    onstack_->push_back(false);
+    coaccess_->push_back(false);
+    if (access_)
+      access_->push_back(false);
+    if (scc_)
+      scc_->push_back(-1);
+  }
+
+  // The discovery number is the count of states seen before this one; a
+  // state starts out as its own lowlink.
+  const StateId order = nstates_++;
+  (*dfnumber_)[s] = order;
+  (*lowlink_)[s] = order;
+  (*onstack_)[s] = true;
+  scc_stack_->push_back(s);
+
+  const bool from_start = (root == start_);
+  if (access_)
+    (*access_)[s] = from_start;
+  if (!from_start) {
+    *props_ |= kNotAccessible;
+    *props_ &= ~kAccessible;
+  }
+  return true;
+}
+
+// Called when the DFS is done with state 's'; 'p' is its parent in the
+// DFS tree, or kNoStateId if it has none. When 's' is the root of an SCC
+// the whole component is popped off the SCC stack, numbered and given a
+// common coaccessibility. Finally coaccessibility and lowlink are handed
+// on to the parent.
+template <class A> inline
+void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) {
+  // A final state is coaccessible by definition.
+  if (fst_->Final(s) != Weight::Zero())
+    (*coaccess_)[s] = true;
+
+  const bool is_scc_root = ((*dfnumber_)[s] == (*lowlink_)[s]);
+  if (is_scc_root) {
+    // Pass 1: the component occupies the stack from the top down to 's'.
+    // It is coaccessible as a whole if any one member is.
+    bool any_coaccess = false;
+    for (size_t pos = scc_stack_->size(); ; ) {
+      const StateId member = (*scc_stack_)[--pos];
+      if ((*coaccess_)[member])
+        any_coaccess = true;
+      if (member == s)
+        break;
+    }
+
+    // Pass 2: pop the members, number them and share the verdict.
+    for (;;) {
+      const StateId member = scc_stack_->back();
+      scc_stack_->pop_back();
+      if (scc_)
+        (*scc_)[member] = nscc_;
+      if (any_coaccess)
+        (*coaccess_)[member] = true;
+      (*onstack_)[member] = false;
+      if (member == s)
+        break;
+    }
+
+    if (!any_coaccess) {
+      *props_ |= kNotCoAccessible;
+      *props_ &= ~kCoAccessible;
+    }
+    ++nscc_;
+  }
+
+  if (p == kNoStateId)
+    return;
+
+  // Hand the results on to the DFS parent.
+  if ((*coaccess_)[s])
+    (*coaccess_)[p] = true;
+  if ((*lowlink_)[s] < (*lowlink_)[p])
+    (*lowlink_)[p] = (*lowlink_)[s];
+}
+
+
+// Trims an FST, removing states and arcs that are not on successful
+// paths. This version modifies its input.
+//
+// A single DFS classifies every state as accessible and/or coaccessible;
+// states lacking either quality are then deleted in one call.
+//
+// Complexity:
+// - Time: O(V + E)
+// - Space: O(V + E)
+// where V = # of states and E = # of arcs.
+template<class Arc>
+void Connect(MutableFst<Arc> *fst) {
+  typedef typename Arc::StateId StateId;
+
+  // Only the accessibility outputs are requested; no SCC numbering.
+  uint64 props = 0;
+  vector<bool> coaccess;
+  vector<bool> access;
+  SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props);
+  DfsVisit(*fst, &scc_visitor);
+
+  // Collect the states that are not both accessible and coaccessible.
+  vector<StateId> dstates;
+  for (StateId s = 0; static_cast<size_t>(s) < access.size(); ++s) {
+    if (!(access[s] && coaccess[s]))
+      dstates.push_back(s);
+  }
+
+  fst->DeleteStates(dstates);
+  // Set both trim property bits on the result.
+  fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible);
+}
+
+} // namespace fst
+
+#endif // FST_LIB_CONNECT_H__