aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/cache.h
blob: 01773960869b149b6480231063145ad509e1ee6a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
// cache.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// An Fst implementation that caches FST elements of a delayed
// computation.

#ifndef FST_LIB_CACHE_H__
#define FST_LIB_CACHE_H__

#include <vector>
using std::vector;
#include <list>

#include <fst/vector-fst.h>


DECLARE_bool(fst_default_cache_gc);
DECLARE_int64(fst_default_cache_gc_limit);

namespace fst {

// Tuning knobs handed to a cache implementation at construction time.
struct CacheOptions {
  bool gc;          // true => cached states may be garbage collected
  size_t gc_limit;  // byte budget that triggers a collection

  // Picks both settings up from the fst_default_cache_gc* command-line flags.
  CacheOptions()
      : gc(FLAGS_fst_default_cache_gc),
        gc_limit(FLAGS_fst_default_cache_gc_limit) {}

  // Uses the explicitly supplied settings.
  CacheOptions(bool g, size_t l)
      : gc(g),
        gc_limit(l) {}
};

// A CacheStateAllocator allocates and frees CacheStates
// template <class S>
// struct CacheStateAllocator {
//   S *Allocate(StateId s);
//   void Free(S *state, StateId s);
// };
//

// A simple allocator class, can be overridden as needed.
// It maintains a single-entry cache: the most recently freed state is kept
// alive and handed back (after S::Reset()) by the next Allocate() call, so a
// Free() immediately followed by an Allocate() reuses the same S*.
//
// The allocator owns the cached state. Copies therefore never share it: a
// copy-constructed allocator starts with an empty cache, and assignment
// leaves the target's own cached state untouched. (The compiler-generated
// copy operations would have duplicated the raw pointer and caused the same
// state to be handed out, and later deleted, twice.)
template <class S>
struct DefaultCacheStateAllocator {
  typedef typename S::Arc::StateId StateId;

  DefaultCacheStateAllocator() : mru_(NULL) { }

  // Copying yields an allocator with an empty single-entry cache.
  DefaultCacheStateAllocator(const DefaultCacheStateAllocator &)
      : mru_(NULL) { }

  // Assignment is a no-op: each allocator keeps (and owns) its own entry.
  DefaultCacheStateAllocator &operator=(const DefaultCacheStateAllocator &) {
    return *this;
  }

  ~DefaultCacheStateAllocator() {
    delete mru_;
  }

  // Returns a state for 's': the cached one (reset) if available, otherwise
  // a freshly constructed S. The state id itself is not used here.
  S *Allocate(StateId s) {
    if (mru_) {
      S *state = mru_;
      mru_ = NULL;
      state->Reset();
      return state;
    }
    return new S();
  }

  // Takes ownership of 'state' (which may be NULL) and keeps it as the
  // cached entry; any previously cached state is destroyed.
  void Free(S *state, StateId s) {
    delete mru_;  // deleting NULL is a no-op
    mru_ = state;
  }

 private:
  S *mru_;  // most recently freed state, or NULL
};

// A VectorFstBaseImpl whose states additionally carry a flags data member
// (see CacheState below). This class is used to cache FST elements, with
// the flags indicating what has been cached. Use HasStart(), HasFinal(),
// and HasArcs() to determine if an element is cached and SetStart(),
// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache it. Note you
// must set the final weight even if the state is non-final to mark it as
// cached. If the 'gc' option is 'false', cached items have the extent
// of the FST - minimizing computation. If the 'gc' option is 'true',
// garbage collection of states (not in use in an arc iterator) is
// performed, in a rough approximation of LRU order, when 'gc_limit'
// bytes are reached - controlling memory use. When 'gc_limit' is 0,
// special optimizations apply - minimizing memory use.

// Template parameters: S is the cached state type and C the allocator used
// to create and recycle states. The allocator held in allocator_ is owned by
// this object and deleted in the destructor, including one passed in by the
// caller.
template <class S, class C = DefaultCacheStateAllocator<S> >
class CacheBaseImpl : public VectorFstBaseImpl<S> {
 public:
  typedef S State;
  typedef C Allocator;
  typedef typename State::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  using FstImpl<Arc>::Type;
  using FstImpl<Arc>::Properties;
  using FstImpl<Arc>::SetProperties;
  using VectorFstBaseImpl<State>::NumStates;
  using VectorFstBaseImpl<State>::AddState;
  using VectorFstBaseImpl<State>::SetState;

  // Builds an empty cache configured from the fst_default_cache_gc* flags.
  // A non-zero limit below kMinCacheLimit is raised to kMinCacheLimit; a
  // limit of 0 is kept as is. Takes ownership of 'allocator' if non-null,
  // otherwise creates a default-constructed C.
  explicit CacheBaseImpl(C *allocator = 0)
      : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
        cache_first_state_id_(kNoStateId), cache_first_state_(0),
        cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
        cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
                     FLAGS_fst_default_cache_gc_limit == 0 ?
                     FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {
          allocator_ = allocator ? allocator : new C();
        }

  // Same as above but configured from 'opts' instead of the flags.
  explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
      : cache_start_(false), nknown_states_(0),
        min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
        cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
        cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
                     opts.gc_limit : kMinCacheLimit) {
          allocator_ = allocator ? allocator : new C();
        }

  // Preserve gc parameters, but initially cache nothing.
  // The copy gets its own freshly constructed allocator.
  CacheBaseImpl(const CacheBaseImpl &impl)
    : cache_start_(false), nknown_states_(0),
      min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
      cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
      cache_limit_(impl.cache_limit_) {
        allocator_ = new C();
      }

  // Hands the separately held first cached state (possibly null) back to the
  // allocator, then destroys the allocator.
  ~CacheBaseImpl() {
    allocator_->Free(cache_first_state_, cache_first_state_id_);
    delete allocator_;
  }

  // Gets a state from its ID; state must exist.
  const S *GetState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; state must exist.
  S *GetState(StateId s) {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; return 0 if it doesn't exist.
  const S *CheckState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else if (s < NumStates())
      return VectorFstBaseImpl<S>::GetState(s);
    else
      return 0;
  }

  // Gets a state from its ID; add it if necessary.
  //
  // A single "first" state is kept outside the main state vector:
  //  - a request for that state returns it directly;
  //  - with a zero cache limit and no first state yet, 's' becomes it;
  //  - if the first state exists and no iterator references it, it is
  //    recycled through the allocator and reused for 's';
  //  - otherwise 's' goes into the main cache, the first state (if any) is
  //    moved there too, and the limit is set to kMinCacheLimit.
  S *ExtendState(StateId s) {
    if (s == cache_first_state_id_) {
      return cache_first_state_;                   // Return 1st cached state
    } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
      cache_first_state_id_ = s;                   // Remember 1st cached state
      cache_first_state_ = allocator_->Allocate(s);
      return cache_first_state_;
    } else if (cache_first_state_id_ != kNoStateId &&
               cache_first_state_->ref_count == 0) {
      // With Default allocator, the Free and Allocate will reuse the same S*.
      allocator_->Free(cache_first_state_, cache_first_state_id_);
      cache_first_state_id_ = s;
      cache_first_state_ = allocator_->Allocate(s);
      return cache_first_state_;                   // Return 1st cached state
    } else {
      while (NumStates() <= s)                     // Add state to main cache
        AddState(0);
      if (!VectorFstBaseImpl<S>::GetState(s)) {
        SetState(s, allocator_->Allocate(s));
        if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
          while (NumStates() <= cache_first_state_id_)
            AddState(0);
          SetState(cache_first_state_id_, cache_first_state_);
          if (cache_gc_) {
            cache_states_.push_back(cache_first_state_id_);
            cache_size_ += sizeof(S) +
                           cache_first_state_->arcs.capacity() * sizeof(Arc);
          }
          cache_limit_ = kMinCacheLimit;
          cache_first_state_id_ = kNoStateId;
          cache_first_state_ = 0;
        }
        if (cache_gc_) {
          cache_states_.push_back(s);
          cache_size_ += sizeof(S);
          if (cache_size_ > cache_limit_)
            GC(s, false);
        }
      }
      S *state = VectorFstBaseImpl<S>::GetState(s);
      return state;
    }
  }

  // Records the start state, marks it as cached and makes sure it is
  // counted among the known states.
  void SetStart(StateId s) {
    VectorFstBaseImpl<S>::SetStart(s);
    cache_start_ = true;
    if (s >= nknown_states_)
      nknown_states_ = s + 1;
  }

  // Caches the final weight of state s (creating the state if needed).
  void SetFinal(StateId s, Weight w) {
    S *state = ExtendState(s);
    state->final = w;
    state->flags |= kCacheFinal | kCacheRecent | kCacheModified;
  }

  // AddArc adds a single arc to state s and does incremental cache
  // book-keeping.  For efficiency, prefer PushArc and SetArcs below
  // when possible.
  void AddArc(StateId s, const Arc &arc) {
    S *state = ExtendState(s);
    // The properties update wants the arc that was last *before* this one
    // (or 0 if there is none), so it has to be looked up, and consumed,
    // prior to the push_back below: afterwards the vector is never empty,
    // back() is the new arc itself, and a reallocation may have invalidated
    // any earlier pointer into the vector.
    const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
    SetProperties(AddArcProperties(Properties(), s, arc, parc));
    state->arcs.push_back(arc);
    if (arc.ilabel == 0) {
      ++state->niepsilons;
    }
    if (arc.olabel == 0) {
      ++state->noepsilons;
    }
    state->flags |= kCacheModified;
    // The separately held first state is not charged to the cache size.
    if (cache_gc_ && s != cache_first_state_id_) {
      cache_size_ += sizeof(Arc);
      if (cache_size_ > cache_limit_)
        GC(s, false);
    }
  }

  // Adds a single arc to state s but delays cache book-keeping.
  // SetArcs must be called when all PushArc calls at a state are
  // complete.  Do not mix with calls to AddArc.
  void PushArc(StateId s, const Arc &arc) {
    S *state = ExtendState(s);
    state->arcs.push_back(arc);
  }

  // Marks arcs of state s as cached and does cache book-keeping after all
  // calls to PushArc have been completed.  Do not mix with calls to AddArc.
  // Recomputes the epsilon counts from scratch, extends the known-state
  // count to cover every destination state, and marks s as expanded.
  void SetArcs(StateId s) {
    S *state = ExtendState(s);
    vector<Arc> &arcs = state->arcs;
    state->niepsilons = state->noepsilons = 0;
    for (size_t a = 0; a < arcs.size(); ++a) {
      const Arc &arc = arcs[a];
      if (arc.nextstate >= nknown_states_)
        nknown_states_ = arc.nextstate + 1;
      if (arc.ilabel == 0)
        ++state->niepsilons;
      if (arc.olabel == 0)
        ++state->noepsilons;
    }
    ExpandedState(s);
    state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_) {
      cache_size_ += arcs.capacity() * sizeof(Arc);
      if (cache_size_ > cache_limit_)
        GC(s, false);
    }
  }

  // Reserves storage for n arcs at state s (creating the state if needed).
  void ReserveArcs(StateId s, size_t n) {
    S *state = ExtendState(s);
    state->arcs.reserve(n);
  }

  // Deletes the last n arcs of state s, keeping the epsilon counts in sync.
  // n is clamped to the number of arcs present so that an oversized request
  // removes all arcs instead of indexing out of bounds.
  void DeleteArcs(StateId s, size_t n) {
    S *state = ExtendState(s);
    const vector<Arc> &arcs = state->arcs;
    if (n > arcs.size())
      n = arcs.size();
    for (size_t i = 0; i < n; ++i) {
      size_t j = arcs.size() - i - 1;
      if (arcs[j].ilabel == 0)
        --state->niepsilons;
      if (arcs[j].olabel == 0)
        --state->noepsilons;
    }
    state->arcs.resize(arcs.size() - n);
    SetProperties(DeleteArcsProperties(Properties()));
    state->flags |= kCacheModified;
  }

  // Deletes all arcs of state s.
  void DeleteArcs(StateId s) {
    S *state = ExtendState(s);
    state->niepsilons = 0;
    state->noepsilons = 0;
    state->arcs.clear();
    SetProperties(DeleteArcsProperties(Properties()));
    state->flags |= kCacheModified;
  }

  // Is the start state cached?
  // Once the kError property is set, the start state is reported (and
  // remembered) as cached.
  bool HasStart() const {
    if (!cache_start_ && Properties(kError))
      cache_start_ = true;
    return cache_start_;
  }

  // Is the final weight of state s cached?
  // A hit also marks the state as recently used.
  bool HasFinal(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheFinal) {
      state->flags |= kCacheRecent;
      return true;
    } else {
      return false;
    }
  }

  // Are arcs of state s cached?
  // A hit also marks the state as recently used.
  bool HasArcs(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheArcs) {
      state->flags |= kCacheRecent;
      return true;
    } else {
      return false;
    }
  }

  // Accessors below go through GetState(), so state s must exist.
  Weight Final(StateId s) const {
    const S *state = GetState(s);
    return state->final;
  }

  size_t NumArcs(StateId s) const {
    const S *state = GetState(s);
    return state->arcs.size();
  }

  size_t NumInputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->niepsilons;
  }

  size_t NumOutputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->noepsilons;
  }

  // Provides information needed for generic arc iterator.
  // Increments the state's reference count through data->ref_count; a
  // non-zero count keeps GC() and ExtendState() from recycling the state.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    const S *state = GetState(s);
    data->base = 0;
    data->narcs = state->arcs.size();
    data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
    data->ref_count = &(state->ref_count);
    ++(*data->ref_count);
  }

  // Number of known states.
  StateId NumKnownStates() const { return nknown_states_; }

  // Update number of known states taking in account the existence of state s.
  void UpdateNumKnownStates(StateId s) {
    if (s >= nknown_states_)
      nknown_states_ = s + 1;
  }

  // Find the minimum never-expanded state Id.
  // Advances (and remembers) the scan position past expanded states.
  StateId MinUnexpandedState() const {
    while (min_unexpanded_state_id_ < expanded_states_.size() &&
          expanded_states_[min_unexpanded_state_id_])
      ++min_unexpanded_state_id_;
    return min_unexpanded_state_id_;
  }

  // Removes from cache_states_ and uncaches (not referenced-counted)
  // states that have not been accessed since the last GC until the cache
  // size is down to about 2/3 of cache_limit_.  If that fails to free
  // enough, recurs uncaching recently visited states as well. If still
  // unable to free enough memory, then widens cache_limit_.
  // State 'current' is never uncached.
  void GC(StateId current, bool free_recent) {
    if (!cache_gc_)
      return;
    VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
            << "), free recently cached = " << free_recent
            << ", cache size = " << cache_size_
            << ", cache limit = " << cache_limit_ << "\n";
    typename list<StateId>::iterator siter = cache_states_.begin();

    size_t cache_target = (2 * cache_limit_)/3 + 1;
    while (siter != cache_states_.end()) {
      StateId s = *siter;
      S* state = VectorFstBaseImpl<S>::GetState(s);
      if (cache_size_ > cache_target && state->ref_count == 0 &&
          (free_recent || !(state->flags & kCacheRecent)) && s != current) {
        cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
        allocator_->Free(state, s);
        SetState(s, 0);
        cache_states_.erase(siter++);
      } else {
        // Survivors lose their "recent" mark so that they become
        // candidates for the next collection unless touched again.
        state->flags &= ~kCacheRecent;
        ++siter;
      }
    }
    if (!free_recent && cache_size_ > cache_target) {
      GC(current, true);
    } else {
      while (cache_size_ > cache_target) {
        cache_limit_ *= 2;
        cache_target *= 2;
      }
    }
    VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
            << "), free recently cached = " << free_recent
            << ", cache size = " << cache_size_
            << ", cache limit = " << cache_limit_ << "\n";
  }

  // Marks state s as expanded. States below the current minimum unexpanded
  // id are already accounted for and are ignored.
  void ExpandedState(StateId s) {
    if (s < min_unexpanded_state_id_)
      return;
    while (expanded_states_.size() <= s)
      expanded_states_.push_back(false);
    expanded_states_[s] = true;
  }

  // Caching on/off switch, limit and size accessors.
  bool GetCacheGc() const { return cache_gc_; }
  size_t GetCacheLimit() const { return cache_limit_; }
  size_t GetCacheSize() const { return cache_size_; }

 private:
  static const size_t kMinCacheLimit = 8096;  // Minimum (non-zero) cache limit
  static const uint32 kCacheFinal =  0x0001;  // Final weight has been cached
  static const uint32 kCacheArcs =   0x0002;  // Arcs have been cached
  static const uint32 kCacheRecent = 0x0004;  // Mark as visited since GC

 public:
  static const uint32 kCacheModified = 0x0008;  // Mark state as modified
  static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
                                    | kCacheModified;

 protected:
  C *allocator_;                             // used to allocate new states

 private:
  mutable bool cache_start_;                 // Is the start state cached?
  StateId nknown_states_;                    // # of known states
  vector<bool> expanded_states_;             // states that have been expanded
  mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
  StateId cache_first_state_id_;             // First cached state id
  S *cache_first_state_;                     // First cached state
  list<StateId> cache_states_;               // list of currently cached states
  bool cache_gc_;                            // enable GC
  size_t cache_size_;                        // # of bytes cached
  size_t cache_limit_;                       // # of bytes allowed before GC

  // Names the class's own type (including the allocator parameter) so that
  // copy assignment is disallowed for every choice of C.
  void operator=(const CacheBaseImpl<S, C> &impl);    // disallow
};

// Out-of-class definitions for the in-class initialized static constants, so
// that they can be ODR-used (e.g. bound to a const reference). kCacheFlags is
// declared in the class alongside the others and is defined here as well.
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFlags;
template <class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit;

// Arcs implemented by an STL vector per state. Similar to VectorState
// but adds flags and ref count to keep track of what has been cached.
//
// The epsilon counters start at zero and are cleared again by Reset(), since
// CacheBaseImpl::AddArc() increments them in place and the default allocator
// recycles state objects after calling Reset().
template <class A>
struct CacheState {
  typedef A Arc;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;

  CacheState()
      : final(Weight::Zero()), niepsilons(0), noepsilons(0),
        flags(0), ref_count(0) {}

  // Returns the state to its freshly constructed condition so that it can
  // be reused for a different state id. The arc vector keeps its capacity.
  void Reset() {
    final = Weight::Zero();
    niepsilons = 0;
    noepsilons = 0;
    flags = 0;
    ref_count = 0;
    arcs.resize(0);
  }

  Weight final;              // Final weight
  vector<A> arcs;            // Arcs representation
  size_t niepsilons;         // # of input epsilons
  size_t noepsilons;         // # of output epsilons
  mutable uint32 flags;      // Cache flags; modified through const access
  mutable int ref_count;     // # of arc iterators referencing this state
};

// A CacheBaseImpl with a commonly used CacheState.
template <class A>
class CacheImpl : public CacheBaseImpl< CacheState<A> > {
 public:
  typedef CacheState<A> State;

  // Configures the cache through the base class default constructor.
  CacheImpl() {}

  // Configures the cache from 'opts'.
  explicit CacheImpl(const CacheOptions &opts)
      : CacheBaseImpl< CacheState<A> >(opts) {}

  // Copy constructor; delegates to the base class copy constructor.
  // The parameter type is CacheImpl<A> (this class); CacheImpl<State> would
  // name the unrelated type CacheImpl< CacheState<A> >.
  CacheImpl(const CacheImpl<A> &impl) : CacheBaseImpl<State>(impl) {}

 private:
  void operator=(const CacheImpl<A> &impl);    // disallow
};


// State iterator for an Fst whose implementation derives from
// CacheBaseImpl; the Fst type must define 'State'.  Only states reachable
// from the initial state are returned, so a class-specific iterator may be
// preferable.
template <class F>
class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
 public:
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename F::State State;
  typedef CacheBaseImpl<State> Impl;

  CacheStateIterator(const F &fst, Impl *impl)
      : fst_(fst), impl_(impl), s_(0) {}

  // Returns false as soon as the current position is a known state.  If it
  // is not, forces the start state and then expands unexpanded states one
  // at a time until either the position becomes known or nothing is left
  // to expand.
  bool Done() const {
    if (PositionIsKnown())
      return false;
    fst_.Start();  // force start state
    if (PositionIsKnown())
      return false;
    while (true) {
      const StateId next = impl_->MinUnexpandedState();
      if (!(next < impl_->NumKnownStates()))
        break;
      // force state expansion
      ArcIterator<F> arcs(fst_, next);
      arcs.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
      while (!arcs.Done()) {
        impl_->UpdateNumKnownStates(arcs.Value().nextstate);
        arcs.Next();
      }
      impl_->ExpandedState(next);
      if (PositionIsKnown())
        return false;
    }
    return true;
  }

  StateId Value() const { return s_; }

  void Next() { ++s_; }

  void Reset() { s_ = 0; }

 private:
  // True when the current position is below the number of known states.
  bool PositionIsKnown() const { return s_ < impl_->NumKnownStates(); }

  // These forward the base class virtual interface to the non-virtual
  // members of the same name above.  That makes this class efficient to use
  // directly but unsafe to derive from further.
  virtual bool Done_() const { return Done(); }
  virtual StateId Value_() const { return Value(); }
  virtual void Next_() { Next(); }
  virtual void Reset_() { Reset(); }

  const F &fst_;
  Impl *impl_;
  StateId s_;
};


// Arc iterator for an Fst whose implementation derives from CacheBaseImpl;
// the Fst type must define 'Arc' and 'State'.  The iterated state's
// reference count is raised for the lifetime of the iterator.
template <class F,
          class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
class CacheArcIterator {
 public:
  typedef typename F::Arc Arc;
  typedef typename F::State State;
  typedef typename Arc::StateId StateId;
  typedef CacheBaseImpl<State, C> Impl;

  // Looks up (creating if necessary) state s and takes a reference on it.
  CacheArcIterator(Impl *impl, StateId s)
      : state_(impl->ExtendState(s)), i_(0) {
    ++state_->ref_count;
  }

  // Drops the reference taken by the constructor.
  ~CacheArcIterator() {
    --state_->ref_count;
  }

  bool Done() const {
    return !(i_ < state_->arcs.size());
  }

  const Arc& Value() const {
    return state_->arcs[i_];
  }

  void Next() { ++i_; }

  size_t Position() const { return i_; }

  void Reset() { i_ = 0; }

  void Seek(size_t a) { i_ = a; }

  // Always reports kArcValueFlags.
  uint32 Flags() const { return kArcValueFlags; }

  // Flags cannot be changed; the request is ignored.
  void SetFlags(uint32 flags, uint32 mask) {}

 private:
  const State *state_;  // state being iterated
  size_t i_;            // current arc position

  DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
};

// Mutable arc iterator for an Fst whose implementation derives from
// CacheBaseImpl; the Fst type must define 'Arc' and 'State'.  The iterated
// state's reference count is raised for the lifetime of the iterator.
template <class F,
          class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
class CacheMutableArcIterator
    : public MutableArcIteratorBase<typename F::Arc> {
 public:
  typedef typename F::State State;
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef CacheBaseImpl<State, C> Impl;

  // You will need to call MutateCheck() in the constructor.
  CacheMutableArcIterator(Impl *impl, StateId s)
      : i_(0), s_(s), impl_(impl), state_(impl->ExtendState(s)) {
    ++state_->ref_count;
  }

  // Drops the reference taken by the constructor.
  ~CacheMutableArcIterator() { --state_->ref_count; }

  bool Done() const {
    return !(i_ < state_->arcs.size());
  }

  const Arc& Value() const {
    return state_->arcs[i_];
  }

  void Next() { ++i_; }

  size_t Position() const { return i_; }

  void Reset() { i_ = 0; }

  void Seek(size_t a) { i_ = a; }

  // Overwrites the arc at the current position with 'arc', marks the state
  // as modified, and updates the epsilon counts and the FST property bits.
  void SetValue(const Arc& arc) {
    state_->flags |= Impl::kCacheModified;
    uint64 props = impl_->Properties();
    Arc& slot = state_->arcs[i_];

    // Step 1: withdraw the property bits (and epsilon counts) that the arc
    // being replaced contributed.
    if (slot.ilabel != slot.olabel)
      props &= ~kNotAcceptor;
    if (slot.ilabel == 0) {
      --state_->niepsilons;
      props &= ~kIEpsilons;
      if (slot.olabel == 0)
        props &= ~kEpsilons;
    }
    if (slot.olabel == 0) {
      --state_->noepsilons;
      props &= ~kOEpsilons;
    }
    if (slot.weight != Weight::Zero() && slot.weight != Weight::One())
      props &= ~kWeighted;

    // Step 2: store the new arc.
    slot = arc;

    // Step 3: assert the bits the new arc establishes and clear their
    // opposites.
    if (arc.ilabel != arc.olabel) {
      props |= kNotAcceptor;
      props &= ~kAcceptor;
    }
    if (arc.ilabel == 0) {
      ++state_->niepsilons;
      props |= kIEpsilons;
      props &= ~kNoIEpsilons;
      if (arc.olabel == 0) {
        props |= kEpsilons;
        props &= ~kNoEpsilons;
      }
    }
    if (arc.olabel == 0) {
      ++state_->noepsilons;
      props |= kOEpsilons;
      props &= ~kNoOEpsilons;
    }
    if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
      props |= kWeighted;
      props &= ~kUnweighted;
    }

    // Step 4: keep only the bits listed in this mask.
    props &= kSetArcProperties | kAcceptor | kNotAcceptor |
        kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons |
        kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted;
    impl_->SetProperties(props);
  }

  // Always reports kArcValueFlags.
  uint32 Flags() const { return kArcValueFlags; }

  // Flags cannot be changed; the request is ignored.
  void SetFlags(uint32 f, uint32 m) {}

 private:
  // Forwarders from the base class interface to the members above.
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }
  virtual size_t Position_() const { return Position(); }
  virtual void Reset_() { Reset(); }
  virtual void Seek_(size_t a) { Seek(a); }
  virtual void SetValue_(const Arc &a) { SetValue(a); }
  uint32 Flags_() const { return Flags(); }
  void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }

  size_t i_;      // current arc position
  StateId s_;     // id of the state being iterated
  Impl *impl_;    // implementation whose properties are updated
  State *state_;  // state being iterated

  DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator);
};

}  // namespace fst

#endif  // FST_LIB_CACHE_H__