Diffstat (limited to 'src/include/fst/cache.h')
-rw-r--r--  src/include/fst/cache.h | 366
1 file changed, 246 insertions(+), 120 deletions(-)
diff --git a/src/include/fst/cache.h b/src/include/fst/cache.h
index 0177396..7c96fe1 100644
--- a/src/include/fst/cache.h
+++ b/src/include/fst/cache.h
@@ -89,14 +89,15 @@ struct DefaultCacheStateAllocator {
// CacheState below). This class is used to cache FST elements with
// the flags used to indicate what has been cached. Use HasStart()
// HasFinal(), and HasArcs() to determine if cached and SetStart(),
-// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note you
-// must set the final weight even if the state is non-final to mark it as
-// cached. If the 'gc' option is 'false', cached items have the extent
-// of the FST - minimizing computation. If the 'gc' option is 'true',
-// garbage collection of states (not in use in an arc iterator) is
-// performed, in a rough approximation of LRU order, when 'gc_limit'
-// bytes is reached - controlling memory use. When 'gc_limit' is 0,
-// special optimizations apply - minimizing memory use.
+// SetFinal(), AddArc() (or PushArc() and SetArcs()) to cache. Note
+// you must set the final weight even if the state is non-final to
+// mark it as cached. If the 'gc' option is 'false', cached items have
+// the extent of the FST - minimizing computation. If the 'gc' option
+// is 'true', garbage collection of states (not in use in an arc
+// iterator and not 'protected') is performed, in a rough
+// approximation of LRU order, when 'gc_limit' bytes is reached -
+// controlling memory use. When 'gc_limit' is 0, special optimizations
+// apply - minimizing memory use.
template <class S, class C = DefaultCacheStateAllocator<S> >
class CacheBaseImpl : public VectorFstBaseImpl<S> {
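To make the caching contract described above concrete, here is a minimal sketch of the pattern a lazy FST implementation typically follows: check the Has*() accessors before computing anything, then cache through SetStart(), SetFinal(), PushArc() and SetArcs(). The LazyFstImpl class and its Compute*() helpers are illustrative placeholders, not part of this header.

    // Illustrative sketch only (assumes #include <fst/cache.h>).
    // ComputeStart(), ComputeFinal() and ComputeArcs() stand in for the
    // derived FST's on-demand computation.
    template <class A>
    class LazyFstImpl : public fst::CacheImpl<A> {
     public:
      typedef typename A::StateId StateId;
      typedef typename A::Weight Weight;

      StateId Start() {
        if (!this->HasStart())
          this->SetStart(ComputeStart());      // cache the start state
        return fst::CacheImpl<A>::Start();
      }

      Weight Final(StateId s) {
        if (!this->HasFinal(s))                // set even for non-final states
          this->SetFinal(s, ComputeFinal(s));  // (e.g. Weight::Zero())
        return fst::CacheImpl<A>::Final(s);
      }

      void Expand(StateId s) {                 // fill the arc cache for s
        if (this->HasArcs(s)) return;
        std::vector<A> arcs;
        ComputeArcs(s, &arcs);
        for (size_t i = 0; i < arcs.size(); ++i)
          this->PushArc(s, arcs[i]);           // stage arcs ...
        this->SetArcs(s);                      // ... then mark them cached
      }

     private:
      StateId ComputeStart();                  // placeholder declarations
      Weight ComputeFinal(StateId s);
      void ComputeArcs(StateId s, std::vector<A> *arcs);
    };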
@@ -111,8 +112,10 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
using FstImpl<Arc>::Properties;
using FstImpl<Arc>::SetProperties;
using VectorFstBaseImpl<State>::NumStates;
+ using VectorFstBaseImpl<State>::Start;
using VectorFstBaseImpl<State>::AddState;
using VectorFstBaseImpl<State>::SetState;
+ using VectorFstBaseImpl<State>::ReserveStates;
explicit CacheBaseImpl(C *allocator = 0)
: cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
@@ -120,27 +123,57 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0),
cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
FLAGS_fst_default_cache_gc_limit == 0 ?
- FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {
- allocator_ = allocator ? allocator : new C();
- }
+ FLAGS_fst_default_cache_gc_limit : kMinCacheLimit),
+ protect_(false) {
+ allocator_ = allocator ? allocator : new C();
+ }
explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
: cache_start_(false), nknown_states_(0),
min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
- opts.gc_limit : kMinCacheLimit) {
- allocator_ = allocator ? allocator : new C();
- }
+ opts.gc_limit : kMinCacheLimit),
+ protect_(false) {
+ allocator_ = allocator ? allocator : new C();
+ }
- // Preserve gc parameters, but initially cache nothing.
- CacheBaseImpl(const CacheBaseImpl &impl)
- : cache_start_(false), nknown_states_(0),
- min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
- cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
- cache_limit_(impl.cache_limit_) {
- allocator_ = new C();
+ // Preserve gc parameters. If preserve_cache is true, also preserves
+ // cache data.
+ CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false)
+ : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0),
+ min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
+ cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
+ cache_limit_(impl.cache_limit_),
+ protect_(impl.protect_) {
+ allocator_ = new C();
+ if (preserve_cache) {
+ cache_start_ = impl.cache_start_;
+ nknown_states_ = impl.nknown_states_;
+ expanded_states_ = impl.expanded_states_;
+ min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
+ if (impl.cache_first_state_id_ != kNoStateId) {
+ cache_first_state_id_ = impl.cache_first_state_id_;
+ cache_first_state_ = allocator_->Allocate(cache_first_state_id_);
+ *cache_first_state_ = *impl.cache_first_state_;
}
+ cache_states_ = impl.cache_states_;
+ cache_size_ = impl.cache_size_;
+ ReserveStates(impl.NumStates());
+ for (StateId s = 0; s < impl.NumStates(); ++s) {
+ const S *state =
+ static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s);
+ if (state) {
+ S *copied_state = allocator_->Allocate(s);
+ *copied_state = *state;
+ AddState(copied_state);
+ } else {
+ AddState(0);
+ }
+ }
+ VectorFstBaseImpl<S>::SetStart(impl.Start());
+ }
+ }
~CacheBaseImpl() {
allocator_->Free(cache_first_state_, cache_first_state_id_);
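The widened copy constructor above is what lets a derived implementation offer cache-preserving copies. A hedged sketch of the forwarding pattern, with a hypothetical MyLazyFstImpl standing in for a real derived class and forwarding through CacheImpl (whose matching constructor appears later in this diff):

    // Illustrative only (assumes #include <fst/cache.h>).
    template <class A>
    class MyLazyFstImpl : public fst::CacheImpl<A> {
     public:
      MyLazyFstImpl(const MyLazyFstImpl<A> &impl, bool preserve_cache = false)
          : fst::CacheImpl<A>(impl, preserve_cache) {
        // preserve_cache == false: gc/gc_limit are copied, cache starts empty.
        // preserve_cache == true:  cached states, expanded_states_ and
        //                          cache_size_ are deep-copied as well.
      }
    };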
@@ -174,49 +207,7 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
}
// Gets a state from its ID; add it if necessary.
- S *ExtendState(StateId s) {
- if (s == cache_first_state_id_) {
- return cache_first_state_; // Return 1st cached state
- } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
- cache_first_state_id_ = s; // Remember 1st cached state
- cache_first_state_ = allocator_->Allocate(s);
- return cache_first_state_;
- } else if (cache_first_state_id_ != kNoStateId &&
- cache_first_state_->ref_count == 0) {
- // With Default allocator, the Free and Allocate will reuse the same S*.
- allocator_->Free(cache_first_state_, cache_first_state_id_);
- cache_first_state_id_ = s;
- cache_first_state_ = allocator_->Allocate(s);
- return cache_first_state_; // Return 1st cached state
- } else {
- while (NumStates() <= s) // Add state to main cache
- AddState(0);
- if (!VectorFstBaseImpl<S>::GetState(s)) {
- SetState(s, allocator_->Allocate(s));
- if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state
- while (NumStates() <= cache_first_state_id_)
- AddState(0);
- SetState(cache_first_state_id_, cache_first_state_);
- if (cache_gc_) {
- cache_states_.push_back(cache_first_state_id_);
- cache_size_ += sizeof(S) +
- cache_first_state_->arcs.capacity() * sizeof(Arc);
- }
- cache_limit_ = kMinCacheLimit;
- cache_first_state_id_ = kNoStateId;
- cache_first_state_ = 0;
- }
- if (cache_gc_) {
- cache_states_.push_back(s);
- cache_size_ += sizeof(S);
- if (cache_size_ > cache_limit_)
- GC(s, false);
- }
- }
- S *state = VectorFstBaseImpl<S>::GetState(s);
- return state;
- }
- }
+ S *ExtendState(StateId s);
void SetStart(StateId s) {
VectorFstBaseImpl<S>::SetStart(s);
@@ -246,7 +237,8 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
SetProperties(AddArcProperties(Properties(), s, arc, parc));
state->flags |= kCacheModified;
- if (cache_gc_ && s != cache_first_state_id_) {
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
cache_size_ += sizeof(Arc);
if (cache_size_ > cache_limit_)
GC(s, false);
@@ -278,7 +270,8 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
}
ExpandedState(s);
state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
- if (cache_gc_ && s != cache_first_state_id_) {
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
cache_size_ += arcs.capacity() * sizeof(Arc);
if (cache_size_ > cache_limit_)
GC(s, false);
@@ -300,18 +293,73 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
if (arcs[j].olabel == 0)
--state->noepsilons;
}
+
state->arcs.resize(arcs.size() - n);
SetProperties(DeleteArcsProperties(Properties()));
state->flags |= kCacheModified;
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
+ cache_size_ -= n * sizeof(Arc);
+ }
}
void DeleteArcs(StateId s) {
S *state = ExtendState(s);
+ size_t n = state->arcs.size();
state->niepsilons = 0;
state->noepsilons = 0;
state->arcs.clear();
SetProperties(DeleteArcsProperties(Properties()));
state->flags |= kCacheModified;
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
+ cache_size_ -= n * sizeof(Arc);
+ }
+ }
+
+ void DeleteStates(const vector<StateId> &dstates) {
+ size_t old_num_states = NumStates();
+ vector<StateId> newid(old_num_states, 0);
+ for (size_t i = 0; i < dstates.size(); ++i)
+ newid[dstates[i]] = kNoStateId;
+ StateId nstates = 0;
+ for (StateId s = 0; s < old_num_states; ++s) {
+ if (newid[s] != kNoStateId) {
+ newid[s] = nstates;
+ ++nstates;
+ }
+ }
+ // Used here just for its states_.resize(); it does an unnecessary walk.
+ VectorFstBaseImpl<S>::DeleteStates(dstates);
+ SetProperties(DeleteStatesProperties(Properties()));
+ // Update list of cached states.
+ typename list<StateId>::iterator siter = cache_states_.begin();
+ while (siter != cache_states_.end()) {
+ if (newid[*siter] != kNoStateId) {
+ *siter = newid[*siter];
+ ++siter;
+ } else {
+ cache_states_.erase(siter++);
+ }
+ }
+ }
+
+ void DeleteStates() {
+ cache_states_.clear();
+ allocator_->Free(cache_first_state_, cache_first_state_id_);
+ for (int s = 0; s < NumStates(); ++s) {
+ allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s);
+ SetState(s, 0);
+ }
+ nknown_states_ = 0;
+ min_unexpanded_state_id_ = 0;
+ cache_first_state_id_ = kNoStateId;
+ cache_first_state_ = 0;
+ cache_size_ = 0;
+ cache_start_ = false;
+ VectorFstBaseImpl<State>::DeleteStates();
+ SetProperties(DeleteAllStatesProperties(Properties(),
+ kExpanded | kMutable));
}
// Is the start state cached?
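To make the renumbering in DeleteStates(const vector<StateId>&) above easier to follow, here is a standalone sketch of the same newid computation (plain int stands in for StateId, and kNoStateId is -1 as in fst.h); it is illustrative only:

    // Needs <vector>.  With old_num_states = 5 and dstates = {1, 3}, newid
    // becomes {0, -1, 1, -1, 2}: surviving cache_states_ entries 2 and 4 are
    // rewritten to 1 and 2, while entries 1 and 3 are erased.
    std::vector<int> MapDeletedStates(int old_num_states,
                                      const std::vector<int> &dstates) {
      const int kNoStateId = -1;                // as in <fst/fst.h>
      std::vector<int> newid(old_num_states, 0);
      for (std::size_t i = 0; i < dstates.size(); ++i)
        newid[dstates[i]] = kNoStateId;         // mark deleted states
      int nstates = 0;
      for (int s = 0; s < old_num_states; ++s)  // pack survivors downward
        if (newid[s] != kNoStateId) newid[s] = nstates++;
      return newid;
    }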
@@ -390,48 +438,17 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
return min_unexpanded_state_id_;
}
- // Removes from cache_states_ and uncaches (not referenced-counted)
- // states that have not been accessed since the last GC until
- // cache_limit_/3 bytes are uncached. If that fails to free enough,
- // recurs uncaching recently visited states as well. If still
- // unable to free enough memory, then widens cache_limit_.
- void GC(StateId current, bool free_recent) {
- if (!cache_gc_)
- return;
- VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
- << "), free recently cached = " << free_recent
- << ", cache size = " << cache_size_
- << ", cache limit = " << cache_limit_ << "\n";
- typename list<StateId>::iterator siter = cache_states_.begin();
+ // Removes from cache_states_ and uncaches (non-reference-counted,
+ // unprotected) states that have not been accessed since the last
+ // GC until at most cache_fraction * cache_limit_ bytes are cached.
+ // If that fails to free enough, recurses, uncaching recently visited
+ // states as well. If still unable to free enough memory, widens
+ // cache_limit_ to fulfill the condition.
+ void GC(StateId current, bool free_recent, float cache_fraction = 0.666);
- size_t cache_target = (2 * cache_limit_)/3 + 1;
- while (siter != cache_states_.end()) {
- StateId s = *siter;
- S* state = VectorFstBaseImpl<S>::GetState(s);
- if (cache_size_ > cache_target && state->ref_count == 0 &&
- (free_recent || !(state->flags & kCacheRecent)) && s != current) {
- cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
- allocator_->Free(state, s);
- SetState(s, 0);
- cache_states_.erase(siter++);
- } else {
- state->flags &= ~kCacheRecent;
- ++siter;
- }
- }
- if (!free_recent && cache_size_ > cache_target) {
- GC(current, true);
- } else {
- while (cache_size_ > cache_target) {
- cache_limit_ *= 2;
- cache_target *= 2;
- }
- }
- VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
- << "), free recently cached = " << free_recent
- << ", cache size = " << cache_size_
- << ", cache limit = " << cache_limit_ << "\n";
- }
+ // Sets/clears GC protection: if true, new states are protected
+ // from garbage collection.
+ void GCProtect(bool on) { protect_ = on; }
void ExpandedState(StateId s) {
if (s < min_unexpanded_state_id_)
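A hedged sketch of how the protection switch above might be used from inside a derived implementation (the state range bounds are illustrative): states extended while protection is on are flagged kCacheProtect in ExtendState(), are never added to cache_states_, and therefore survive any later GC() sweep.

    // Illustrative only, from inside a CacheBaseImpl-derived member function.
    GCProtect(true);                // states created from here on are pinned
    for (StateId s = begin; s < end; ++s)
      ExtendState(s);               // flagged kCacheProtect, not GC-swept
    GCProtect(false);               // subsequent states are GC-eligible again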
@@ -441,26 +458,30 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
expanded_states_[s] = true;
}
+ C *GetAllocator() const {
+ return allocator_;
+ }
+
// Caching on/off switch, limit and size accessors.
bool GetCacheGc() const { return cache_gc_; }
size_t GetCacheLimit() const { return cache_limit_; }
size_t GetCacheSize() const { return cache_size_; }
private:
- static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit
- static const uint32 kCacheFinal = 0x0001; // Final weight has been cached
- static const uint32 kCacheArcs = 0x0002; // Arcs have been cached
- static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC
+ static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit
+
+ static const uint32 kCacheFinal = 0x0001; // Final weight has been cached
+ static const uint32 kCacheArcs = 0x0002; // Arcs have been cached
+ static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC
+ static const uint32 kCacheProtect = 0x0008; // Mark state as GC protected
public:
- static const uint32 kCacheModified = 0x0008; // Mark state as modified
+ static const uint32 kCacheModified = 0x0010; // Mark state as modified
static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
- | kCacheModified;
-
- protected:
- C *allocator_; // used to allocate new states
+ | kCacheProtect | kCacheModified;
private:
+ C *allocator_; // used to allocate new states
mutable bool cache_start_; // Is the start state cached?
StateId nknown_states_; // # of known states
vector<bool> expanded_states_; // states that have been expanded
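Because kCacheProtect takes over the 0x0008 bit, kCacheModified moves to 0x0010. Any code that tested the modified bit by numeric value rather than by name needs the symbolic constant, as in this small sketch (S and C are the impl's template parameters and state is assumed to be a cached S*):

    // Illustrative only: test the public flag bits by name, never by value.
    uint32 cached_bits = state->flags & fst::CacheBaseImpl<S, C>::kCacheFlags;
    if (cached_bits & fst::CacheBaseImpl<S, C>::kCacheModified) {
      // The state has changed since it was first cached.
    }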
@@ -471,10 +492,113 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
bool cache_gc_; // enable GC
size_t cache_size_; // # of bytes cached
size_t cache_limit_; // # of bytes allowed before GC
+ bool protect_; // Protect new states from GC
- void operator=(const CacheBaseImpl<S> &impl); // disallow
+ void operator=(const CacheBaseImpl<S, C> &impl); // disallow
};
+// Gets a state from its ID; adds it if necessary.
+template <class S, class C>
+S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) {
+ // If 'protect_' is true and this is a new state, it is protected
+ // from garbage collection.
+ if (s == cache_first_state_id_) {
+ return cache_first_state_; // Return 1st cached state
+ } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
+ cache_first_state_id_ = s; // Remember 1st cached state
+ cache_first_state_ = allocator_->Allocate(s);
+ if (protect_) cache_first_state_->flags |= kCacheProtect;
+ return cache_first_state_;
+ } else if (cache_first_state_id_ != kNoStateId &&
+ cache_first_state_->ref_count == 0 &&
+ !(cache_first_state_->flags & kCacheProtect)) {
+ // With Default allocator, the Free and Allocate will reuse the same S*.
+ allocator_->Free(cache_first_state_, cache_first_state_id_);
+ cache_first_state_id_ = s;
+ cache_first_state_ = allocator_->Allocate(s);
+ if (protect_) cache_first_state_->flags |= kCacheProtect;
+ return cache_first_state_; // Return 1st cached state
+ } else {
+ while (NumStates() <= s) // Add state to main cache
+ AddState(0);
+ S *state = VectorFstBaseImpl<S>::GetState(s);
+ if (!state) {
+ state = allocator_->Allocate(s);
+ if (protect_) state->flags |= kCacheProtect;
+ SetState(s, state);
+ if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state
+ while (NumStates() <= cache_first_state_id_)
+ AddState(0);
+ SetState(cache_first_state_id_, cache_first_state_);
+ if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) {
+ cache_states_.push_back(cache_first_state_id_);
+ cache_size_ += sizeof(S) +
+ cache_first_state_->arcs.capacity() * sizeof(Arc);
+ }
+ cache_limit_ = kMinCacheLimit;
+ cache_first_state_id_ = kNoStateId;
+ cache_first_state_ = 0;
+ }
+ if (cache_gc_ && !protect_) {
+ cache_states_.push_back(s);
+ cache_size_ += sizeof(S);
+ if (cache_size_ > cache_limit_)
+ GC(s, false);
+ }
+ }
+ return state;
+ }
+}
+
+// Removes from cache_states_ and uncaches (non-reference-counted,
+// unprotected) states that have not been accessed since the last GC
+// until at most cache_fraction * cache_limit_ bytes are cached. If
+// that fails to free enough, recurses, uncaching recently visited
+// states as well. If still unable to free enough memory, widens
+// cache_limit_ to fulfill the condition.
+template <class S, class C>
+void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current,
+ bool free_recent, float cache_fraction) {
+ if (!cache_gc_)
+ return;
+ VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
+ << "), free recently cached = " << free_recent
+ << ", cache size = " << cache_size_
+ << ", cache frac = " << cache_fraction
+ << ", cache limit = " << cache_limit_ << "\n";
+ typename list<StateId>::iterator siter = cache_states_.begin();
+
+ size_t cache_target = cache_fraction * cache_limit_;
+ while (siter != cache_states_.end()) {
+ StateId s = *siter;
+ S* state = VectorFstBaseImpl<S>::GetState(s);
+ if (cache_size_ > cache_target && state->ref_count == 0 &&
+ (free_recent || !(state->flags & kCacheRecent)) && s != current) {
+ cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
+ allocator_->Free(state, s);
+ SetState(s, 0);
+ cache_states_.erase(siter++);
+ } else {
+ state->flags &= ~kCacheRecent;
+ ++siter;
+ }
+ }
+ if (!free_recent && cache_size_ > cache_target) { // recurses on recent
+ GC(current, true);
+ } else if (cache_target > 0) { // widens cache limit
+ while (cache_size_ > cache_target) {
+ cache_limit_ *= 2;
+ cache_target *= 2;
+ }
+ } else if (cache_size_ > 0) {
+ FSTERROR() << "CacheImpl:GC: Unable to free all cached states";
+ }
+ VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
+ << "), free recently cached = " << free_recent
+ << ", cache size = " << cache_size_
+ << ", cache frac = " << cache_fraction
+ << ", cache limit = " << cache_limit_ << "\n";
+}
+
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
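Worked numbers for the defaults above, to make the target computation and the widening branch concrete (the leftover cache_size_ is made up for illustration):

    // cache_limit_   = 8096 bytes (kMinCacheLimit), cache_fraction = 0.666
    // cache_target   = size_t(0.666 * 8096) = 5391 bytes
    // Suppose both sweeps (free_recent false, then true) still leave
    // cache_size_ = 12000 bytes of in-use or protected states.  The widening
    // loop then doubles until the target exceeds cache_size_:
    //   5391 -> 10782 -> 21564, leaving cache_limit_ = 32384.
    // With cache_fraction = 0 the target is 0, so instead of widening, GC
    // reports an FSTERROR if any cached state could not be freed.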
@@ -516,7 +640,8 @@ class CacheImpl : public CacheBaseImpl< CacheState<A> > {
explicit CacheImpl(const CacheOptions &opts)
: CacheBaseImpl< CacheState<A> >(opts) {}
- CacheImpl(const CacheImpl<State> &impl) : CacheBaseImpl<State>(impl) {}
+ CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false)
+ : CacheBaseImpl<State>(impl, preserve_cache) {}
private:
void operator=(const CacheImpl<State> &impl); // disallow
@@ -536,12 +661,13 @@ class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
typedef CacheBaseImpl<State> Impl;
CacheStateIterator(const F &fst, Impl *impl)
- : fst_(fst), impl_(impl), s_(0) {}
+ : fst_(fst), impl_(impl), s_(0) {
+ fst_.Start(); // force start state
+ }
bool Done() const {
if (s_ < impl_->NumKnownStates())
return false;
- fst_.Start(); // force start state
if (s_ < impl_->NumKnownStates())
return false;
for (StateId u = impl_->MinUnexpandedState();