diff options
Diffstat (limited to 'src/include/fst/extensions/pdt/replace.h')
-rw-r--r-- | src/include/fst/extensions/pdt/replace.h | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h index a85d0fe..9081400 100644 --- a/src/include/fst/extensions/pdt/replace.h +++ b/src/include/fst/extensions/pdt/replace.h @@ -21,6 +21,10 @@ #ifndef FST_EXTENSIONS_PDT_REPLACE_H__ #define FST_EXTENSIONS_PDT_REPLACE_H__ +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + #include <fst/replace.h> namespace fst { @@ -62,11 +66,14 @@ void Replace(const vector<pair<typename Arc::Label, label2id[ifst_array[i].first] = i; Label max_label = kNoLabel; + size_t max_non_term_count = 0; - deque<size_t> non_term_queue; // Queue of non-terminals to replace - unordered_set<Label> non_term_set; // Set of non-terminals to replace + // Queue of non-terminals to replace + deque<size_t> non_term_queue; + // Map of non-terminals to replace to count + unordered_map<Label, size_t> non_term_map; non_term_queue.push_back(root); - non_term_set.insert(root); + non_term_map[root] = 1;; // PDT state corr. to ith replace FST start state. vector<StateId> fst_start(ifst_array.size(), kNoLabel); @@ -107,10 +114,11 @@ void Replace(const vector<pair<typename Arc::Label, size_t nfst_id = it->second; if (ifst_array[nfst_id].second->Start() == -1) continue; - if (non_term_set.count(arc.olabel) == 0) { + size_t count = non_term_map[arc.olabel]++; + if (count == 0) non_term_queue.push_back(arc.olabel); - non_term_set.insert(arc.olabel); - } + if (count > max_non_term_count) + max_non_term_count = count; } arc.nextstate += soff; ofst->AddArc(os, arc); @@ -134,7 +142,8 @@ void Replace(const vector<pair<typename Arc::Label, // # of parenthesis pairs per fst. vector<size_t> nparens(ifst_array.size(), 0); // Initial open parenthesis label - Label first_paren = max_label + 1; + Label first_open_paren = max_label + 1; + Label first_close_paren = max_label + max_non_term_count + 1; for (StateIterator< Fst<Arc> > siter(*ofst); !siter.Done(); siter.Next()) { @@ -158,8 +167,8 @@ void Replace(const vector<pair<typename Arc::Label, close_paren = (*parens)[paren_id].second; } else { size_t paren_id = nparens[nfst_id]++; - open_paren = first_paren + 2 * paren_id; - close_paren = open_paren + 1; + open_paren = first_open_paren + paren_id; + close_paren = first_close_paren + paren_id; paren_map[paren_key] = paren_id; if (paren_id >= parens->size()) parens->push_back(make_pair(open_paren, close_paren)); |