diff options
Diffstat (limited to 'src/test/weight_test.cc')
-rw-r--r-- | src/test/weight_test.cc | 258 |
1 files changed, 258 insertions, 0 deletions
diff --git a/src/test/weight_test.cc b/src/test/weight_test.cc new file mode 100644 index 0000000..54ba85d --- /dev/null +++ b/src/test/weight_test.cc @@ -0,0 +1,258 @@ +// weight_test.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Regression test for Fst weights. + +#include <cstdlib> +#include <ctime> + +#include <fst/expectation-weight.h> +#include <fst/float-weight.h> +#include <fst/random-weight.h> +#include "./weight-tester.h" + +DEFINE_int32(seed, -1, "random seed"); +DEFINE_int32(repeat, 100000, "number of test repetitions"); + +using fst::TropicalWeight; +using fst::TropicalWeightGenerator; +using fst::TropicalWeightTpl; +using fst::TropicalWeightGenerator_; + +using fst::LogWeight; +using fst::LogWeightGenerator; +using fst::LogWeightTpl; +using fst::LogWeightGenerator_; + +using fst::MinMaxWeight; +using fst::MinMaxWeightGenerator; +using fst::MinMaxWeightTpl; +using fst::MinMaxWeightGenerator_; + +using fst::StringWeight; +using fst::StringWeightGenerator; + +using fst::GallicWeight; +using fst::GallicWeightGenerator; + +using fst::LexicographicWeight; +using fst::LexicographicWeightGenerator; + +using fst::ProductWeight; +using fst::ProductWeightGenerator; + +using fst::PowerWeight; +using fst::PowerWeightGenerator; + +using fst::SignedLogWeightTpl; +using fst::SignedLogWeightGenerator_; + +using fst::ExpectationWeight; + +using fst::SparsePowerWeight; +using fst::SparsePowerWeightGenerator; + +using fst::STRING_LEFT; +using fst::STRING_RIGHT; + +using fst::WeightTester; + +template <class T> +void TestTemplatedWeights(int repeat, int seed) { + TropicalWeightGenerator_<T> tropical_generator(seed); + WeightTester<TropicalWeightTpl<T>, TropicalWeightGenerator_<T> > + tropical_tester(tropical_generator); + tropical_tester.Test(repeat); + + LogWeightGenerator_<T> log_generator(seed); + WeightTester<LogWeightTpl<T>, LogWeightGenerator_<T> > + log_tester(log_generator); + log_tester.Test(repeat); + + MinMaxWeightGenerator_<T> minmax_generator(seed); + WeightTester<MinMaxWeightTpl<T>, MinMaxWeightGenerator_<T> > + minmax_tester(minmax_generator); + minmax_tester.Test(repeat); + + SignedLogWeightGenerator_<T> signedlog_generator(seed); + WeightTester<SignedLogWeightTpl<T>, SignedLogWeightGenerator_<T> > + signedlog_tester(signedlog_generator); + signedlog_tester.Test(repeat); +} + +int main(int argc, char **argv) { + std::set_new_handler(FailedNewHandler); + SetFlags(argv[0], &argc, &argv, true); + + int seed = FLAGS_seed >= 0 ? FLAGS_seed : time(0); + LOG(INFO) << "Seed = " << seed; + + TestTemplatedWeights<float>(FLAGS_repeat, seed); + TestTemplatedWeights<double>(FLAGS_repeat, seed); + FLAGS_fst_weight_parentheses = "()"; + TestTemplatedWeights<float>(FLAGS_repeat, seed); + TestTemplatedWeights<double>(FLAGS_repeat, seed); + FLAGS_fst_weight_parentheses = ""; + + // Make sure type names for templated weights are consistent + CHECK(TropicalWeight::Type() == "tropical"); + CHECK(TropicalWeightTpl<double>::Type() != TropicalWeightTpl<float>::Type()); + CHECK(LogWeight::Type() == "log"); + CHECK(LogWeightTpl<double>::Type() != LogWeightTpl<float>::Type()); + TropicalWeightTpl<double> w(15.0); + TropicalWeight tw(15.0); + + StringWeightGenerator<int> left_string_generator(seed); + WeightTester<StringWeight<int>, StringWeightGenerator<int> > + left_string_tester(left_string_generator); + left_string_tester.Test(FLAGS_repeat); + + StringWeightGenerator<int, STRING_RIGHT> right_string_generator(seed); + WeightTester<StringWeight<int, STRING_RIGHT>, + StringWeightGenerator<int, STRING_RIGHT> > + right_string_tester(right_string_generator); + right_string_tester.Test(FLAGS_repeat); + + typedef GallicWeight<int, TropicalWeight> TropicalGallicWeight; + typedef GallicWeightGenerator<int, TropicalWeightGenerator> + TropicalGallicWeightGenerator; + + TropicalGallicWeightGenerator tropical_gallic_generator(seed); + WeightTester<TropicalGallicWeight, TropicalGallicWeightGenerator> + tropical_gallic_tester(tropical_gallic_generator); + tropical_gallic_tester.Test(FLAGS_repeat); + + typedef ProductWeight<TropicalWeight, TropicalWeight> TropicalProductWeight; + typedef ProductWeightGenerator<TropicalWeightGenerator, + TropicalWeightGenerator> TropicalProductWeightGenerator; + + TropicalProductWeightGenerator tropical_product_generator(seed); + WeightTester<TropicalProductWeight, TropicalProductWeightGenerator> + tropical_product_weight_tester(tropical_product_generator); + tropical_product_weight_tester.Test(FLAGS_repeat); + + typedef PowerWeight<TropicalWeight, 3> TropicalCubeWeight; + typedef PowerWeightGenerator<TropicalWeightGenerator, 3> + TropicalCubeWeightGenerator; + + TropicalCubeWeightGenerator tropical_cube_generator(seed); + WeightTester<TropicalCubeWeight, TropicalCubeWeightGenerator> + tropical_cube_weight_tester(tropical_cube_generator); + tropical_cube_weight_tester.Test(FLAGS_repeat); + + typedef ProductWeight<TropicalWeight, TropicalProductWeight> + SecondNestedProductWeight; + typedef ProductWeightGenerator<TropicalWeightGenerator, + TropicalProductWeightGenerator> SecondNestedProductWeightGenerator; + + SecondNestedProductWeightGenerator second_nested_product_generator(seed); + WeightTester<SecondNestedProductWeight, SecondNestedProductWeightGenerator> + second_nested_product_weight_tester(second_nested_product_generator); + second_nested_product_weight_tester.Test(FLAGS_repeat); + + // This only works with fst_weight_parentheses = "()" + typedef ProductWeight<TropicalProductWeight, TropicalWeight> + FirstNestedProductWeight; + typedef ProductWeightGenerator<TropicalProductWeightGenerator, + TropicalWeightGenerator> FirstNestedProductWeightGenerator; + + FirstNestedProductWeightGenerator first_nested_product_generator(seed); + WeightTester<FirstNestedProductWeight, FirstNestedProductWeightGenerator> + first_nested_product_weight_tester(first_nested_product_generator); + + typedef PowerWeight<FirstNestedProductWeight, 3> NestedProductCubeWeight; + typedef PowerWeightGenerator<FirstNestedProductWeightGenerator, 3> + NestedProductCubeWeightGenerator; + + NestedProductCubeWeightGenerator nested_product_cube_generator(seed); + WeightTester<NestedProductCubeWeight, NestedProductCubeWeightGenerator> + nested_product_cube_weight_tester(nested_product_cube_generator); + + typedef SparsePowerWeight<NestedProductCubeWeight, + size_t > SparseNestedProductCubeWeight; + typedef SparsePowerWeightGenerator<NestedProductCubeWeightGenerator, + size_t, 3> SparseNestedProductCubeWeightGenerator; + + SparseNestedProductCubeWeightGenerator + sparse_nested_product_cube_generator(seed); + WeightTester<SparseNestedProductCubeWeight, + SparseNestedProductCubeWeightGenerator> + sparse_nested_product_cube_weight_tester( + sparse_nested_product_cube_generator); + + typedef SparsePowerWeight<LogWeight, size_t > LogSparsePowerWeight; + typedef SparsePowerWeightGenerator<LogWeightGenerator, + size_t, 3> LogSparsePowerWeightGenerator; + + LogSparsePowerWeightGenerator + log_sparse_power_weight_generator(seed); + WeightTester<LogSparsePowerWeight, + LogSparsePowerWeightGenerator> + log_sparse_power_weight_tester( + log_sparse_power_weight_generator); + + typedef ExpectationWeight<LogWeight, LogWeight> + LogLogExpectWeight; + typedef ProductWeightGenerator<LogWeightGenerator, LogWeightGenerator, + LogLogExpectWeight> LogLogExpectWeightGenerator; + + LogLogExpectWeightGenerator log_log_expect_weight_generator(seed); + WeightTester<LogLogExpectWeight, LogLogExpectWeightGenerator> + log_log_expect_weight_tester(log_log_expect_weight_generator); + + typedef ExpectationWeight<LogWeight, LogSparsePowerWeight> + LogLogSparseExpectWeight; + typedef ProductWeightGenerator< + LogWeightGenerator, + LogSparsePowerWeightGenerator, + LogLogSparseExpectWeight> LogLogSparseExpectWeightGenerator; + + LogLogSparseExpectWeightGenerator log_logsparse_expect_weight_generator(seed); + WeightTester<LogLogSparseExpectWeight, LogLogSparseExpectWeightGenerator> + log_logsparse_expect_weight_tester(log_logsparse_expect_weight_generator); + + // Test all product weight I/O with parentheses + FLAGS_fst_weight_parentheses = "()"; + first_nested_product_weight_tester.Test(FLAGS_repeat); + nested_product_cube_weight_tester.Test(FLAGS_repeat); + log_sparse_power_weight_tester.Test(1); + sparse_nested_product_cube_weight_tester.Test(1); + tropical_product_weight_tester.Test(5); + second_nested_product_weight_tester.Test(5); + tropical_gallic_tester.Test(5); + tropical_cube_weight_tester.Test(5); + FLAGS_fst_weight_parentheses = ""; + log_sparse_power_weight_tester.Test(1); + log_log_expect_weight_tester.Test(1, false); // disables division + log_logsparse_expect_weight_tester.Test(1, false); + + typedef LexicographicWeight<TropicalWeight, TropicalWeight> + TropicalLexicographicWeight; + typedef LexicographicWeightGenerator<TropicalWeightGenerator, + TropicalWeightGenerator> TropicalLexicographicWeightGenerator; + + TropicalLexicographicWeightGenerator tropical_lexicographic_generator(seed); + WeightTester<TropicalLexicographicWeight, + TropicalLexicographicWeightGenerator> + tropical_lexicographic_tester(tropical_lexicographic_generator); + tropical_lexicographic_tester.Test(FLAGS_repeat); + + cout << "PASS" << endl; + + return 0; +} |