diff options
Diffstat (limited to 'src/include/fst/extensions/pdt/shortest-path.h')
-rw-r--r-- | src/include/fst/extensions/pdt/shortest-path.h | 790 |
1 files changed, 790 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/shortest-path.h b/src/include/fst/extensions/pdt/shortest-path.h new file mode 100644 index 0000000..e90471b --- /dev/null +++ b/src/include/fst/extensions/pdt/shortest-path.h @@ -0,0 +1,790 @@ +// shortest-path.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Functions to find shortest paths in a PDT. + +#ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ +#define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ + +#include <fst/shortest-path.h> +#include <fst/extensions/pdt/paren.h> +#include <fst/extensions/pdt/pdt.h> + +#include <unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <stack> +#include <vector> +using std::vector; + +namespace fst { + +template <class Arc, class Queue> +struct PdtShortestPathOptions { + bool keep_parentheses; + bool path_gc; + + PdtShortestPathOptions(bool kp = false, bool gc = true) + : keep_parentheses(kp), path_gc(gc) {} +}; + + +// Class to store PDT shortest path results. Stores shortest path +// tree info 'Distance()', Parent(), and ArcParent() information keyed +// on two types: +// (1) By SearchState: This is a usual node in a shortest path tree but: +// (a) is w.r.t a PDT search state - a pair of a PDT state and +// a 'start' state, which is either the PDT start state or +// the destination state of an open parenthesis. +// (b) the Distance() is from this 'start' state to the search state. +// (c) Parent().state is kNoLabel for the 'start' state. +// +// (2) By ParenSpec: This connects shortest path trees depending on the +// the parenthesis taken. Given the parenthesis spec: +// (a) the Distance() is from the Parent() 'start' state to the +// parenthesis destination state. +// (b) the ArcParent() is the parenthesis arc. +template <class Arc> +class PdtShortestPathData { + public: + static const uint8 kFinal; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + struct SearchState { + SearchState() : state(kNoStateId), start(kNoStateId) {} + + SearchState(StateId s, StateId t) : state(s), start(t) {} + + bool operator==(const SearchState &s) const { + if (&s == this) + return true; + return s.state == this->state && s.start == this->start; + } + + StateId state; // PDT state + StateId start; // PDT paren 'source' state + }; + + + // Specifies paren id, source and dest 'start' states of a paren. + // These are the 'start' states of the respective sub-graphs. + struct ParenSpec { + ParenSpec() + : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {} + + ParenSpec(Label id, StateId s, StateId d) + : paren_id(id), src_start(s), dest_start(d) {} + + Label paren_id; // Id of parenthesis + StateId src_start; // sub-graph 'start' state for paren source. + StateId dest_start; // sub-graph 'start' state for paren dest. + + bool operator==(const ParenSpec &x) const { + if (&x == this) + return true; + return x.paren_id == this->paren_id && + x.src_start == this->src_start && + x.dest_start == this->dest_start; + } + }; + + struct SearchData { + SearchData() : distance(Weight::Zero()), + parent(kNoStateId, kNoStateId), + paren_id(kNoLabel), + flags(0) {} + + Weight distance; // Distance to this state from PDT 'start' state + SearchState parent; // Parent state in shortest path tree + int16 paren_id; // If parent arc has paren, paren ID, o.w. kNoLabel + uint8 flags; // First byte reserved for PdtShortestPathData use + }; + + PdtShortestPathData(bool gc) + : state_(kNoStateId, kNoStateId), + paren_(kNoLabel, kNoStateId, kNoStateId), + gc_(gc), + nstates_(0), + ngc_(0), + finished_(false) {} + + ~PdtShortestPathData() { + VLOG(1) << "opm size: " << paren_map_.size(); + VLOG(1) << "# of search states: " << nstates_; + if (gc_) + VLOG(1) << "# of GC'd search states: " << ngc_; + } + + void Clear() { + search_map_.clear(); + search_multimap_.clear(); + paren_map_.clear(); + state_ = SearchState(kNoStateId, kNoStateId); + nstates_ = 0; + ngc_ = 0; + } + + Weight Distance(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->distance; + } + + Weight Distance(const ParenSpec &paren) const { + SearchData *data = GetSearchData(paren); + return data->distance; + } + + SearchState Parent(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->parent; + } + + SearchState Parent(const ParenSpec &paren) const { + SearchData *data = GetSearchData(paren); + return data->parent; + } + + Label ParenId(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->paren_id; + } + + uint8 Flags(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->flags; + } + + void SetDistance(SearchState s, Weight w) { + SearchData *data = GetSearchData(s); + data->distance = w; + } + + void SetDistance(const ParenSpec &paren, Weight w) { + SearchData *data = GetSearchData(paren); + data->distance = w; + } + + void SetParent(SearchState s, SearchState p) { + SearchData *data = GetSearchData(s); + data->parent = p; + } + + void SetParent(const ParenSpec &paren, SearchState p) { + SearchData *data = GetSearchData(paren); + data->parent = p; + } + + void SetParenId(SearchState s, Label p) { + if (p >= 32768) + FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16"; + SearchData *data = GetSearchData(s); + data->paren_id = p; + } + + void SetFlags(SearchState s, uint8 f, uint8 mask) { + SearchData *data = GetSearchData(s); + data->flags &= ~mask; + data->flags |= f & mask; + } + + void GC(StateId s); + + void Finish() { finished_ = true; } + + private: + static const Arc kNoArc; + static const size_t kPrime0; + static const size_t kPrime1; + static const uint8 kInited; + static const uint8 kMarked; + + // Hash for search state + struct SearchStateHash { + size_t operator()(const SearchState &s) const { + return s.state + s.start * kPrime0; + } + }; + + // Hash for paren map + struct ParenHash { + size_t operator()(const ParenSpec &paren) const { + return paren.paren_id + paren.src_start * kPrime0 + + paren.dest_start * kPrime1; + } + }; + + typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap; + + typedef unordered_multimap<StateId, StateId> SearchMultimap; + + // Hash map from paren spec to open paren data + typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap; + + SearchData *GetSearchData(SearchState s) const { + if (s == state_) + return state_data_; + if (finished_) { + typename SearchMap::iterator it = search_map_.find(s); + if (it == search_map_.end()) + return &null_search_data_; + state_ = s; + return state_data_ = &(it->second); + } else { + state_ = s; + state_data_ = &search_map_[s]; + if (!(state_data_->flags & kInited)) { + ++nstates_; + if (gc_) + search_multimap_.insert(make_pair(s.start, s.state)); + state_data_->flags = kInited; + } + return state_data_; + } + } + + SearchData *GetSearchData(ParenSpec paren) const { + if (paren == paren_) + return paren_data_; + if (finished_) { + typename ParenMap::iterator it = paren_map_.find(paren); + if (it == paren_map_.end()) + return &null_search_data_; + paren_ = paren; + return state_data_ = &(it->second); + } else { + paren_ = paren; + return paren_data_ = &paren_map_[paren]; + } + } + + mutable SearchMap search_map_; // Maps from search state to data + mutable SearchMultimap search_multimap_; // Maps from 'start' to subgraph + mutable ParenMap paren_map_; // Maps paren spec to search data + mutable SearchState state_; // Last state accessed + mutable SearchData *state_data_; // Last state data accessed + mutable ParenSpec paren_; // Last paren spec accessed + mutable SearchData *paren_data_; // Last paren data accessed + bool gc_; // Allow GC? + mutable size_t nstates_; // Total number of search states + size_t ngc_; // Number of GC'd search states + mutable SearchData null_search_data_; // Null search data + bool finished_; // Read-only access when true + + DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData); +}; + +// Deletes inaccessible search data from a given 'start' (open paren dest) +// state. Assumes 'final' (close paren source or PDT final) states have +// been flagged 'kFinal'. +template<class Arc> +void PdtShortestPathData<Arc>::GC(StateId start) { + if (!gc_) + return; + vector<StateId> final; + for (typename SearchMultimap::iterator mmit = search_multimap_.find(start); + mmit != search_multimap_.end() && mmit->first == start; + ++mmit) { + SearchState s(mmit->second, start); + const SearchData &data = search_map_[s]; + if (data.flags & kFinal) + final.push_back(s.state); + } + + // Mark phase + for (size_t i = 0; i < final.size(); ++i) { + SearchState s(final[i], start); + while (s.state != kNoLabel) { + SearchData *sdata = &search_map_[s]; + if (sdata->flags & kMarked) + break; + sdata->flags |= kMarked; + SearchState p = sdata->parent; + if (p.start != start && p.start != kNoLabel) { // entering sub-subgraph + ParenSpec paren(sdata->paren_id, s.start, p.start); + SearchData *pdata = &paren_map_[paren]; + s = pdata->parent; + } else { + s = p; + } + } + } + + // Sweep phase + typename SearchMultimap::iterator mmit = search_multimap_.find(start); + while (mmit != search_multimap_.end() && mmit->first == start) { + SearchState s(mmit->second, start); + typename SearchMap::iterator mit = search_map_.find(s); + const SearchData &data = mit->second; + if (!(data.flags & kMarked)) { + search_map_.erase(mit); + ++ngc_; + } + search_multimap_.erase(mmit++); + } +} + +template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc + = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); + +template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853; + +template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867; + +template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01; + +template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal = 0x02; + +template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04; + + +// This computes the single source shortest (balanced) path (SSSP) +// through a weighted PDT that has a bounded stack (i.e. is expandable +// as an FST). It is a generalization of the classic SSSP graph +// algorithm that removes a state s from a queue (defined by a +// user-provided queue type) and relaxes the destination states of +// transitions leaving s. In this PDT version, states that have +// entering open parentheses are treated as source states for a +// sub-graph SSSP problem with the shortest path up to the open +// parenthesis being first saved. When a close parenthesis is then +// encountered any balancing open parenthesis is examined for this +// saved information and multiplied back. In this way, each sub-graph +// is entered only once rather than repeatedly. If every state in the +// input PDT has the property that there is a unique 'start' state for +// it with entering open parentheses, then this algorithm is quite +// straight-forward. In general, this will not be the case, so the +// algorithm (implicitly) creates a new graph where each state is a +// pair of an original state and a possible parenthesis 'start' state +// for that state. +template<class Arc, class Queue> +class PdtShortestPath { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + typedef PdtShortestPathData<Arc> SpData; + typedef typename SpData::SearchState SearchState; + typedef typename SpData::ParenSpec ParenSpec; + + typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator; + typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator; + + PdtShortestPath(const Fst<Arc> &ifst, + const vector<pair<Label, Label> > &parens, + const PdtShortestPathOptions<Arc, Queue> &opts) + : kFinal(SpData::kFinal), + ifst_(ifst.Copy()), + parens_(parens), + keep_parens_(opts.keep_parentheses), + start_(ifst.Start()), + sp_data_(opts.path_gc), + error_(false) { + + if ((Weight::Properties() & (kPath | kRightSemiring)) + != (kPath | kRightSemiring)) { + FSTERROR() << "SingleShortestPath: Weight needs to have the path" + << " property and be right distributive: " << Weight::Type(); + error_ = true; + } + + for (Label i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_id_map_[p.first] = i; + paren_id_map_[p.second] = i; + } + }; + + ~PdtShortestPath() { + VLOG(1) << "# of input states: " << CountStates(*ifst_); + VLOG(1) << "# of enqueued: " << nenqueued_; + VLOG(1) << "cpmm size: " << close_paren_multimap_.size(); + delete ifst_; + } + + void ShortestPath(MutableFst<Arc> *ofst) { + Init(ofst); + GetDistance(start_); + GetPath(); + sp_data_.Finish(); + if (error_) ofst->SetProperties(kError, kError); + } + + const PdtShortestPathData<Arc> &GetShortestPathData() const { + return sp_data_; + } + + PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; } + + private: + static const Arc kNoArc; + static const uint8 kEnqueued; + static const uint8 kExpanded; + const uint8 kFinal; + + public: + // Hash multimap from close paren label to an paren arc. + typedef unordered_multimap<ParenState<Arc>, Arc, + typename ParenState<Arc>::Hash> CloseParenMultimap; + + const CloseParenMultimap &GetCloseParenMultimap() const { + return close_paren_multimap_; + } + + private: + void Init(MutableFst<Arc> *ofst); + void GetDistance(StateId start); + void ProcFinal(SearchState s); + void ProcArcs(SearchState s); + void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w); + void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w); + void ProcNonParen(SearchState s, const Arc &arc, Weight w); + void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id); + void Enqueue(SearchState d); + void GetPath(); + Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open); + + Fst<Arc> *ifst_; + MutableFst<Arc> *ofst_; + const vector<pair<Label, Label> > &parens_; + bool keep_parens_; + Queue *state_queue_; // current state queue + StateId start_; + Weight f_distance_; + SearchState f_parent_; + SpData sp_data_; + unordered_map<Label, Label> paren_id_map_; + CloseParenMultimap close_paren_multimap_; + PdtBalanceData<Arc> balance_data_; + ssize_t nenqueued_; + bool error_; + + DISALLOW_COPY_AND_ASSIGN(PdtShortestPath); +}; + +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) { + ofst_ = ofst; + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst_->InputSymbols()); + ofst->SetOutputSymbols(ifst_->OutputSymbols()); + + if (ifst_->Start() == kNoStateId) + return; + + f_distance_ = Weight::Zero(); + f_parent_ = SearchState(kNoStateId, kNoStateId); + + sp_data_.Clear(); + close_paren_multimap_.clear(); + balance_data_.Clear(); + nenqueued_ = 0; + + // Find open parens per destination state and close parens per source state. + for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // Is a paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) { // Open paren + balance_data_.OpenInsert(paren_id, arc.nextstate); + } else { // Close paren + ParenState<Arc> paren_state(paren_id, s); + close_paren_multimap_.insert(make_pair(paren_state, arc)); + } + } + } + } +} + +// Computes the shortest distance stored in a recursive way. Each +// sub-graph (i.e. different paren 'start' state) begins with weight One(). +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) { + if (start == kNoStateId) + return; + + Queue state_queue; + state_queue_ = &state_queue; + SearchState q(start, start); + Enqueue(q); + sp_data_.SetDistance(q, Weight::One()); + + while (!state_queue_->Empty()) { + StateId state = state_queue_->Head(); + state_queue_->Dequeue(); + SearchState s(state, start); + sp_data_.SetFlags(s, 0, kEnqueued); + ProcFinal(s); + ProcArcs(s); + sp_data_.SetFlags(s, kExpanded, kExpanded); + } + balance_data_.FinishInsert(start); + sp_data_.GC(start); +} + +// Updates best complete path. +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) { + if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) { + Weight w = Times(sp_data_.Distance(s), + ifst_->Final(s.state)); + if (f_distance_ != Plus(f_distance_, w)) { + if (f_parent_.state != kNoStateId) + sp_data_.SetFlags(f_parent_, 0, kFinal); + sp_data_.SetFlags(s, kFinal, kFinal); + + f_distance_ = Plus(f_distance_, w); + f_parent_ = s; + } + } +} + +// Processes all arcs leaving the state s. +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) { + for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + Weight w = Times(sp_data_.Distance(s), arc.weight); + + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // Is a paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) + ProcOpenParen(paren_id, s, arc, w); + else + ProcCloseParen(paren_id, s, arc, w); + } else { + ProcNonParen(s, arc, w); + } + } +} + +// Saves the shortest path info for reaching this parenthesis +// and starts a new SSSP in the sub-graph pointed to by the parenthesis +// if previously unvisited. Otherwise it finds any previously encountered +// closing parentheses and relaxes them using the recursively stored +// shortest distance to them. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::ProcOpenParen( + Label paren_id, SearchState s, Arc arc, Weight w) { + + SearchState d(arc.nextstate, arc.nextstate); + ParenSpec paren(paren_id, s.start, d.start); + Weight pdist = sp_data_.Distance(paren); + if (pdist != Plus(pdist, w)) { + sp_data_.SetDistance(paren, w); + sp_data_.SetParent(paren, s); + Weight dist = sp_data_.Distance(d); + if (dist == Weight::Zero()) { + Queue *state_queue = state_queue_; + GetDistance(d.start); + state_queue_ = state_queue; + } + for (CloseSourceIterator set_iter = + balance_data_.Find(paren_id, arc.nextstate); + !set_iter.Done(); set_iter.Next()) { + SearchState cpstate(set_iter.Element(), d.start); + ParenState<Arc> paren_state(paren_id, cpstate.state); + for (typename CloseParenMultimap::const_iterator cpit = + close_paren_multimap_.find(paren_state); + cpit != close_paren_multimap_.end() && paren_state == cpit->first; + ++cpit) { + const Arc &cparc = cpit->second; + Weight cpw = Times(w, Times(sp_data_.Distance(cpstate), + cparc.weight)); + Relax(cpstate, s, cparc, cpw, paren_id); + } + } + } +} + +// Saves the correspondence between each closing parenthesis and its +// balancing open parenthesis info. Relaxes any close parenthesis +// destination state that has a balancing previously encountered open +// parenthesis. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::ProcCloseParen( + Label paren_id, SearchState s, const Arc &arc, Weight w) { + ParenState<Arc> paren_state(paren_id, s.start); + if (!(sp_data_.Flags(s) & kExpanded)) { + balance_data_.CloseInsert(paren_id, s.start, s.state); + sp_data_.SetFlags(s, kFinal, kFinal); + } +} + +// For non-parentheses, classical relaxation. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::ProcNonParen( + SearchState s, const Arc &arc, Weight w) { + Relax(s, s, arc, w, kNoLabel); +} + +// Classical relaxation on the search graph for 'arc' from state 's'. +// State 't' is in the same sub-graph as the nextstate should be (i.e. +// has the same paren 'start'. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::Relax( + SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) { + SearchState d(arc.nextstate, t.start); + Weight dist = sp_data_.Distance(d); + if (dist != Plus(dist, w)) { + sp_data_.SetParent(d, s); + sp_data_.SetParenId(d, paren_id); + sp_data_.SetDistance(d, Plus(dist, w)); + Enqueue(d); + } +} + +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) { + if (!(sp_data_.Flags(s) & kEnqueued)) { + state_queue_->Enqueue(s.state); + sp_data_.SetFlags(s, kEnqueued, kEnqueued); + ++nenqueued_; + } else { + state_queue_->Update(s.state); + } +} + +// Follows parent pointers to find the shortest path. Uses a stack +// since the shortest distance is stored recursively. +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::GetPath() { + SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId); + StateId s_p = kNoStateId, d_p = kNoStateId; + Arc arc(kNoArc); + Label paren_id = kNoLabel; + stack<ParenSpec> paren_stack; + while (s.state != kNoStateId) { + d_p = s_p; + s_p = ofst_->AddState(); + if (d.state == kNoStateId) { + ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state)); + } else { + if (paren_id != kNoLabel) { // paren? + if (arc.ilabel == parens_[paren_id].first) { // open paren + paren_stack.pop(); + } else { // close paren + ParenSpec paren(paren_id, d.start, s.start); + paren_stack.push(paren); + } + if (!keep_parens_) + arc.ilabel = arc.olabel = 0; + } + arc.nextstate = d_p; + ofst_->AddArc(s_p, arc); + } + d = s; + s = sp_data_.Parent(d); + paren_id = sp_data_.ParenId(d); + if (s.state != kNoStateId) { + arc = GetPathArc(s, d, paren_id, false); + } else if (!paren_stack.empty()) { + ParenSpec paren = paren_stack.top(); + s = sp_data_.Parent(paren); + paren_id = paren.paren_id; + arc = GetPathArc(s, d, paren_id, true); + } + } + ofst_->SetStart(s_p); + ofst_->SetProperties( + ShortestPathProperties(ofst_->Properties(kFstProperties, false)), + kFstProperties); +} + + +// Finds transition with least weight between two states with label matching +// paren_id and open/close paren type or a non-paren if kNoLabel. +template<class Arc, class Queue> +Arc PdtShortestPath<Arc, Queue>::GetPathArc( + SearchState s, SearchState d, Label paren_id, bool open_paren) { + Arc path_arc = kNoArc; + for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.nextstate != d.state) + continue; + Label arc_paren_id = kNoLabel; + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { + arc_paren_id = pit->second; + bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first; + if (arc_open_paren != open_paren) + continue; + } + if (arc_paren_id != paren_id) + continue; + if (arc.weight == Plus(arc.weight, path_arc.weight)) + path_arc = arc; + } + if (path_arc.nextstate == kNoStateId) { + FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc"; + error_ = true; + } + return path_arc; +} + +template<class Arc, class Queue> +const Arc PdtShortestPath<Arc, Queue>::kNoArc + = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); + +template<class Arc, class Queue> +const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10; + +template<class Arc, class Queue> +const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20; + +template<class Arc, class Queue> +void ShortestPath(const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + MutableFst<Arc> *ofst, + const PdtShortestPathOptions<Arc, Queue> &opts) { + PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); + psp.ShortestPath(ofst); +} + +template<class Arc> +void ShortestPath(const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + MutableFst<Arc> *ofst) { + typedef FifoQueue<typename Arc::StateId> Queue; + PdtShortestPathOptions<Arc, Queue> opts; + PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); + psp.ShortestPath(ofst); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ |