aboutsummaryrefslogtreecommitdiff
path: root/src/test/weight_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/weight_test.cc')
-rw-r--r--src/test/weight_test.cc258
1 files changed, 258 insertions, 0 deletions
diff --git a/src/test/weight_test.cc b/src/test/weight_test.cc
new file mode 100644
index 0000000..54ba85d
--- /dev/null
+++ b/src/test/weight_test.cc
@@ -0,0 +1,258 @@
+// weight_test.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Regression test for Fst weights.
+
+#include <cstdlib>
+#include <ctime>
+
+#include <fst/expectation-weight.h>
+#include <fst/float-weight.h>
+#include <fst/random-weight.h>
+#include "./weight-tester.h"
+
+DEFINE_int32(seed, -1, "random seed");
+DEFINE_int32(repeat, 100000, "number of test repetitions");
+
+using fst::TropicalWeight;
+using fst::TropicalWeightGenerator;
+using fst::TropicalWeightTpl;
+using fst::TropicalWeightGenerator_;
+
+using fst::LogWeight;
+using fst::LogWeightGenerator;
+using fst::LogWeightTpl;
+using fst::LogWeightGenerator_;
+
+using fst::MinMaxWeight;
+using fst::MinMaxWeightGenerator;
+using fst::MinMaxWeightTpl;
+using fst::MinMaxWeightGenerator_;
+
+using fst::StringWeight;
+using fst::StringWeightGenerator;
+
+using fst::GallicWeight;
+using fst::GallicWeightGenerator;
+
+using fst::LexicographicWeight;
+using fst::LexicographicWeightGenerator;
+
+using fst::ProductWeight;
+using fst::ProductWeightGenerator;
+
+using fst::PowerWeight;
+using fst::PowerWeightGenerator;
+
+using fst::SignedLogWeightTpl;
+using fst::SignedLogWeightGenerator_;
+
+using fst::ExpectationWeight;
+
+using fst::SparsePowerWeight;
+using fst::SparsePowerWeightGenerator;
+
+using fst::STRING_LEFT;
+using fst::STRING_RIGHT;
+
+using fst::WeightTester;
+
+template <class T>
+void TestTemplatedWeights(int repeat, int seed) {
+ TropicalWeightGenerator_<T> tropical_generator(seed);
+ WeightTester<TropicalWeightTpl<T>, TropicalWeightGenerator_<T> >
+ tropical_tester(tropical_generator);
+ tropical_tester.Test(repeat);
+
+ LogWeightGenerator_<T> log_generator(seed);
+ WeightTester<LogWeightTpl<T>, LogWeightGenerator_<T> >
+ log_tester(log_generator);
+ log_tester.Test(repeat);
+
+ MinMaxWeightGenerator_<T> minmax_generator(seed);
+ WeightTester<MinMaxWeightTpl<T>, MinMaxWeightGenerator_<T> >
+ minmax_tester(minmax_generator);
+ minmax_tester.Test(repeat);
+
+ SignedLogWeightGenerator_<T> signedlog_generator(seed);
+ WeightTester<SignedLogWeightTpl<T>, SignedLogWeightGenerator_<T> >
+ signedlog_tester(signedlog_generator);
+ signedlog_tester.Test(repeat);
+}
+
+int main(int argc, char **argv) {
+ std::set_new_handler(FailedNewHandler);
+ SetFlags(argv[0], &argc, &argv, true);
+
+ int seed = FLAGS_seed >= 0 ? FLAGS_seed : time(0);
+ LOG(INFO) << "Seed = " << seed;
+
+ TestTemplatedWeights<float>(FLAGS_repeat, seed);
+ TestTemplatedWeights<double>(FLAGS_repeat, seed);
+ FLAGS_fst_weight_parentheses = "()";
+ TestTemplatedWeights<float>(FLAGS_repeat, seed);
+ TestTemplatedWeights<double>(FLAGS_repeat, seed);
+ FLAGS_fst_weight_parentheses = "";
+
+ // Make sure type names for templated weights are consistent
+ CHECK(TropicalWeight::Type() == "tropical");
+ CHECK(TropicalWeightTpl<double>::Type() != TropicalWeightTpl<float>::Type());
+ CHECK(LogWeight::Type() == "log");
+ CHECK(LogWeightTpl<double>::Type() != LogWeightTpl<float>::Type());
+ TropicalWeightTpl<double> w(15.0);
+ TropicalWeight tw(15.0);
+
+ StringWeightGenerator<int> left_string_generator(seed);
+ WeightTester<StringWeight<int>, StringWeightGenerator<int> >
+ left_string_tester(left_string_generator);
+ left_string_tester.Test(FLAGS_repeat);
+
+ StringWeightGenerator<int, STRING_RIGHT> right_string_generator(seed);
+ WeightTester<StringWeight<int, STRING_RIGHT>,
+ StringWeightGenerator<int, STRING_RIGHT> >
+ right_string_tester(right_string_generator);
+ right_string_tester.Test(FLAGS_repeat);
+
+ typedef GallicWeight<int, TropicalWeight> TropicalGallicWeight;
+ typedef GallicWeightGenerator<int, TropicalWeightGenerator>
+ TropicalGallicWeightGenerator;
+
+ TropicalGallicWeightGenerator tropical_gallic_generator(seed);
+ WeightTester<TropicalGallicWeight, TropicalGallicWeightGenerator>
+ tropical_gallic_tester(tropical_gallic_generator);
+ tropical_gallic_tester.Test(FLAGS_repeat);
+
+ typedef ProductWeight<TropicalWeight, TropicalWeight> TropicalProductWeight;
+ typedef ProductWeightGenerator<TropicalWeightGenerator,
+ TropicalWeightGenerator> TropicalProductWeightGenerator;
+
+ TropicalProductWeightGenerator tropical_product_generator(seed);
+ WeightTester<TropicalProductWeight, TropicalProductWeightGenerator>
+ tropical_product_weight_tester(tropical_product_generator);
+ tropical_product_weight_tester.Test(FLAGS_repeat);
+
+ typedef PowerWeight<TropicalWeight, 3> TropicalCubeWeight;
+ typedef PowerWeightGenerator<TropicalWeightGenerator, 3>
+ TropicalCubeWeightGenerator;
+
+ TropicalCubeWeightGenerator tropical_cube_generator(seed);
+ WeightTester<TropicalCubeWeight, TropicalCubeWeightGenerator>
+ tropical_cube_weight_tester(tropical_cube_generator);
+ tropical_cube_weight_tester.Test(FLAGS_repeat);
+
+ typedef ProductWeight<TropicalWeight, TropicalProductWeight>
+ SecondNestedProductWeight;
+ typedef ProductWeightGenerator<TropicalWeightGenerator,
+ TropicalProductWeightGenerator> SecondNestedProductWeightGenerator;
+
+ SecondNestedProductWeightGenerator second_nested_product_generator(seed);
+ WeightTester<SecondNestedProductWeight, SecondNestedProductWeightGenerator>
+ second_nested_product_weight_tester(second_nested_product_generator);
+ second_nested_product_weight_tester.Test(FLAGS_repeat);
+
+ // This only works with fst_weight_parentheses = "()"
+ typedef ProductWeight<TropicalProductWeight, TropicalWeight>
+ FirstNestedProductWeight;
+ typedef ProductWeightGenerator<TropicalProductWeightGenerator,
+ TropicalWeightGenerator> FirstNestedProductWeightGenerator;
+
+ FirstNestedProductWeightGenerator first_nested_product_generator(seed);
+ WeightTester<FirstNestedProductWeight, FirstNestedProductWeightGenerator>
+ first_nested_product_weight_tester(first_nested_product_generator);
+
+ typedef PowerWeight<FirstNestedProductWeight, 3> NestedProductCubeWeight;
+ typedef PowerWeightGenerator<FirstNestedProductWeightGenerator, 3>
+ NestedProductCubeWeightGenerator;
+
+ NestedProductCubeWeightGenerator nested_product_cube_generator(seed);
+ WeightTester<NestedProductCubeWeight, NestedProductCubeWeightGenerator>
+ nested_product_cube_weight_tester(nested_product_cube_generator);
+
+ typedef SparsePowerWeight<NestedProductCubeWeight,
+ size_t > SparseNestedProductCubeWeight;
+ typedef SparsePowerWeightGenerator<NestedProductCubeWeightGenerator,
+ size_t, 3> SparseNestedProductCubeWeightGenerator;
+
+ SparseNestedProductCubeWeightGenerator
+ sparse_nested_product_cube_generator(seed);
+ WeightTester<SparseNestedProductCubeWeight,
+ SparseNestedProductCubeWeightGenerator>
+ sparse_nested_product_cube_weight_tester(
+ sparse_nested_product_cube_generator);
+
+ typedef SparsePowerWeight<LogWeight, size_t > LogSparsePowerWeight;
+ typedef SparsePowerWeightGenerator<LogWeightGenerator,
+ size_t, 3> LogSparsePowerWeightGenerator;
+
+ LogSparsePowerWeightGenerator
+ log_sparse_power_weight_generator(seed);
+ WeightTester<LogSparsePowerWeight,
+ LogSparsePowerWeightGenerator>
+ log_sparse_power_weight_tester(
+ log_sparse_power_weight_generator);
+
+ typedef ExpectationWeight<LogWeight, LogWeight>
+ LogLogExpectWeight;
+ typedef ProductWeightGenerator<LogWeightGenerator, LogWeightGenerator,
+ LogLogExpectWeight> LogLogExpectWeightGenerator;
+
+ LogLogExpectWeightGenerator log_log_expect_weight_generator(seed);
+ WeightTester<LogLogExpectWeight, LogLogExpectWeightGenerator>
+ log_log_expect_weight_tester(log_log_expect_weight_generator);
+
+ typedef ExpectationWeight<LogWeight, LogSparsePowerWeight>
+ LogLogSparseExpectWeight;
+ typedef ProductWeightGenerator<
+ LogWeightGenerator,
+ LogSparsePowerWeightGenerator,
+ LogLogSparseExpectWeight> LogLogSparseExpectWeightGenerator;
+
+ LogLogSparseExpectWeightGenerator log_logsparse_expect_weight_generator(seed);
+ WeightTester<LogLogSparseExpectWeight, LogLogSparseExpectWeightGenerator>
+ log_logsparse_expect_weight_tester(log_logsparse_expect_weight_generator);
+
+ // Test all product weight I/O with parentheses
+ FLAGS_fst_weight_parentheses = "()";
+ first_nested_product_weight_tester.Test(FLAGS_repeat);
+ nested_product_cube_weight_tester.Test(FLAGS_repeat);
+ log_sparse_power_weight_tester.Test(1);
+ sparse_nested_product_cube_weight_tester.Test(1);
+ tropical_product_weight_tester.Test(5);
+ second_nested_product_weight_tester.Test(5);
+ tropical_gallic_tester.Test(5);
+ tropical_cube_weight_tester.Test(5);
+ FLAGS_fst_weight_parentheses = "";
+ log_sparse_power_weight_tester.Test(1);
+ log_log_expect_weight_tester.Test(1, false); // disables division
+ log_logsparse_expect_weight_tester.Test(1, false);
+
+ typedef LexicographicWeight<TropicalWeight, TropicalWeight>
+ TropicalLexicographicWeight;
+ typedef LexicographicWeightGenerator<TropicalWeightGenerator,
+ TropicalWeightGenerator> TropicalLexicographicWeightGenerator;
+
+ TropicalLexicographicWeightGenerator tropical_lexicographic_generator(seed);
+ WeightTester<TropicalLexicographicWeight,
+ TropicalLexicographicWeightGenerator>
+ tropical_lexicographic_tester(tropical_lexicographic_generator);
+ tropical_lexicographic_tester.Test(FLAGS_repeat);
+
+ cout << "PASS" << endl;
+
+ return 0;
+}