diff options
author | Xin Li <delphij@google.com> | 2018-08-06 16:50:37 -0700 |
---|---|---|
committer | Xin Li <delphij@google.com> | 2018-08-06 16:50:37 -0700 |
commit | efbea3cb0f4119e28a1e9e4ba2f42da6b603d156 (patch) | |
tree | f36e9f2e71e84c4103f4f9ef577ee7b41980345a | |
parent | 0b3ea84b83c3fa8108f9879d7f141445644fcde8 (diff) | |
parent | 02a6fb6aad9bd613b986db84f44598bb6f4b5da9 (diff) | |
download | libtextclassifier-master-cuttlefish-testing-release.tar.gz |
Merge Android Pie into masterandroid-o-mr1-iot-release-smart-display-r3android-o-mr1-iot-release-1.0.5android-o-mr1-iot-release-1.0.4android-o-mr1-iot-release-1.0.3oreo-mr1-1.2-iot-releasemaster-cuttlefish-testing-release
Bug: 112104996
Change-Id: I1dd339c1bf8396642516c6c4ed82e195601fb316
204 files changed, 15535 insertions, 11758 deletions
@@ -36,27 +36,26 @@ MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS := \ MY_LIBTEXTCLASSIFIER_CFLAGS := \ $(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS) \ - -fvisibility=hidden + -fvisibility=hidden \ + -DLIBTEXTCLASSIFIER_UNILIB_ICU \ + -DZLIB_CONST # Only enable debug logging in userdebug/eng builds. ifneq (,$(filter userdebug eng, $(TARGET_BUILD_VARIANT))) MY_LIBTEXTCLASSIFIER_CFLAGS += -DTC_DEBUG_LOGGING=1 endif -# ------------------------ -# libtextclassifier_protos -# ------------------------ +# ----------------- +# flatbuffers +# ----------------- +# Empty static library so that other projects can include just the basic +# FlatBuffers headers as a module. include $(CLEAR_VARS) - -LOCAL_MODULE := libtextclassifier_protos - -LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS) - -LOCAL_SRC_FILES := $(call all-proto-files-under, .) -LOCAL_SHARED_LIBRARIES := libprotobuf-cpp-lite - -LOCAL_CFLAGS := $(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS) +LOCAL_MODULE := flatbuffers +LOCAL_EXPORT_C_INCLUDES := $(LOCAL_PATH)/include +LOCAL_EXPORT_CPPFLAGS := -std=c++11 -fexceptions -Wall \ + -DFLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE include $(BUILD_STATIC_LIBRARY) @@ -67,23 +66,31 @@ include $(BUILD_STATIC_LIBRARY) include $(CLEAR_VARS) LOCAL_MODULE := libtextclassifier -proto_sources_dir := $(generated_sources_dir) - LOCAL_CPP_EXTENSION := .cc LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS) LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS) -LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc,$(call all-subdir-cpp-files)) -LOCAL_C_INCLUDES += $(proto_sources_dir)/proto/external/libtextclassifier +LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc test-util.%,$(call all-subdir-cpp-files)) + +LOCAL_C_INCLUDES := $(TOP)/external/zlib +LOCAL_C_INCLUDES += $(TOP)/external/tensorflow +LOCAL_C_INCLUDES += $(TOP)/external/flatbuffers/include -LOCAL_STATIC_LIBRARIES += libtextclassifier_protos -LOCAL_SHARED_LIBRARIES += libprotobuf-cpp-lite LOCAL_SHARED_LIBRARIES += liblog -LOCAL_SHARED_LIBRARIES += libicuuc 
libicui18n -LOCAL_REQUIRED_MODULES := textclassifier.smartselection.en.model +LOCAL_SHARED_LIBRARIES += libicuuc +LOCAL_SHARED_LIBRARIES += libicui18n +LOCAL_SHARED_LIBRARIES += libtflite +LOCAL_SHARED_LIBRARIES += libz + +LOCAL_STATIC_LIBRARIES += flatbuffers + +LOCAL_REQUIRED_MODULES := textclassifier.en.model +LOCAL_REQUIRED_MODULES += textclassifier.universal.model LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds +LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\"" +LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\"" include $(BUILD_SHARED_LIBRARY) @@ -101,162 +108,45 @@ LOCAL_CPP_EXTENSION := .cc LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS) LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS) -LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, tests/testdata) +LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, test_data) -LOCAL_CPPFLAGS_32 += -DTEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/tests/testdata/\"" -LOCAL_CPPFLAGS_64 += -DTEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/tests/testdata/\"" +LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\"" +LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\"" LOCAL_SRC_FILES := $(call all-subdir-cpp-files) -LOCAL_C_INCLUDES += $(proto_sources_dir)/proto/external/libtextclassifier -LOCAL_STATIC_LIBRARIES += libtextclassifier_protos libgmock -LOCAL_SHARED_LIBRARIES += libprotobuf-cpp-lite -LOCAL_SHARED_LIBRARIES += liblog -LOCAL_SHARED_LIBRARIES += libicuuc libicui18n +LOCAL_C_INCLUDES := $(TOP)/external/zlib +LOCAL_C_INCLUDES += $(TOP)/external/tensorflow +LOCAL_C_INCLUDES += $(TOP)/external/flatbuffers/include -include $(BUILD_NATIVE_TEST) 
+LOCAL_STATIC_LIBRARIES += libgmock +LOCAL_SHARED_LIBRARIES += liblog +LOCAL_SHARED_LIBRARIES += libicuuc +LOCAL_SHARED_LIBRARIES += libicui18n +LOCAL_SHARED_LIBRARIES += libtflite +LOCAL_SHARED_LIBRARIES += libz -# ------------ -# LangId model -# ------------ +LOCAL_STATIC_LIBRARIES += flatbuffers -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.langid.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.langid.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) +include $(BUILD_NATIVE_TEST) # ---------------------- # Smart Selection models # ---------------------- include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.ar.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ar.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.de.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.de.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.en.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.en.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.es.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.es.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.fr.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := 
./models/textclassifier.smartselection.fr.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.it.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.it.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.ja.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ja.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.ko.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ko.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.nl.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.nl.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.pl.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.pl.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.pt.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.pt.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.ru.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google 
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ru.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.th.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.th.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.tr.model -LOCAL_MODULE_CLASS := ETC -LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.tr.model -LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier -include $(BUILD_PREBUILT) - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.zh-Hant.model +LOCAL_MODULE := textclassifier.en.model LOCAL_MODULE_CLASS := ETC LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.zh-Hant.model +LOCAL_SRC_FILES := ./models/textclassifier.en.model LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier include $(BUILD_PREBUILT) include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.zh.model +LOCAL_MODULE := textclassifier.universal.model LOCAL_MODULE_CLASS := ETC LOCAL_MODULE_OWNER := google -LOCAL_SRC_FILES := ./models/textclassifier.smartselection.zh.model +LOCAL_SRC_FILES := ./models/textclassifier.universal.model LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier include $(BUILD_PREBUILT) @@ -265,10 +155,7 @@ include $(BUILD_PREBUILT) # ----------------------- include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.smartselection.bundle1 -LOCAL_REQUIRED_MODULES := textclassifier.smartselection.en.model -LOCAL_REQUIRED_MODULES += textclassifier.smartselection.es.model -LOCAL_REQUIRED_MODULES += textclassifier.smartselection.de.model -LOCAL_REQUIRED_MODULES += textclassifier.smartselection.fr.model +LOCAL_MODULE := textclassifier.bundle1 +LOCAL_REQUIRED_MODULES := 
textclassifier.en.model LOCAL_CFLAGS := $(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS) include $(BUILD_STATIC_LIBRARY) diff --git a/cached-features.cc b/cached-features.cc new file mode 100644 index 0000000..2a46780 --- /dev/null +++ b/cached-features.cc @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cached-features.h" + +#include "tensor-view.h" +#include "util/base/logging.h" + +namespace libtextclassifier2 { + +namespace { + +int CalculateOutputFeaturesSize(const FeatureProcessorOptions* options, + int feature_vector_size) { + const bool bounds_sensitive_enabled = + options->bounds_sensitive_features() && + options->bounds_sensitive_features()->enabled(); + + int num_extracted_tokens = 0; + if (bounds_sensitive_enabled) { + const FeatureProcessorOptions_::BoundsSensitiveFeatures* config = + options->bounds_sensitive_features(); + num_extracted_tokens += config->num_tokens_before(); + num_extracted_tokens += config->num_tokens_inside_left(); + num_extracted_tokens += config->num_tokens_inside_right(); + num_extracted_tokens += config->num_tokens_after(); + if (config->include_inside_bag()) { + ++num_extracted_tokens; + } + } else { + num_extracted_tokens = 2 * options->context_size() + 1; + } + + int output_features_size = num_extracted_tokens * feature_vector_size; + + if (bounds_sensitive_enabled && + options->bounds_sensitive_features()->include_inside_length()) { + 
++output_features_size; + } + + return output_features_size; +} + +} // namespace + +std::unique_ptr<CachedFeatures> CachedFeatures::Create( + const TokenSpan& extraction_span, + std::unique_ptr<std::vector<float>> features, + std::unique_ptr<std::vector<float>> padding_features, + const FeatureProcessorOptions* options, int feature_vector_size) { + const int min_feature_version = + options->bounds_sensitive_features() && + options->bounds_sensitive_features()->enabled() + ? 2 + : 1; + if (options->feature_version() < min_feature_version) { + TC_LOG(ERROR) << "Unsupported feature version."; + return nullptr; + } + + std::unique_ptr<CachedFeatures> cached_features(new CachedFeatures()); + cached_features->extraction_span_ = extraction_span; + cached_features->features_ = std::move(features); + cached_features->padding_features_ = std::move(padding_features); + cached_features->options_ = options; + + cached_features->output_features_size_ = + CalculateOutputFeaturesSize(options, feature_vector_size); + + return cached_features; +} + +void CachedFeatures::AppendClickContextFeaturesForClick( + int click_pos, std::vector<float>* output_features) const { + click_pos -= extraction_span_.first; + + AppendFeaturesInternal( + /*intended_span=*/ExpandTokenSpan(SingleTokenSpan(click_pos), + options_->context_size(), + options_->context_size()), + /*read_mask_span=*/{0, TokenSpanSize(extraction_span_)}, output_features); +} + +void CachedFeatures::AppendBoundsSensitiveFeaturesForSpan( + TokenSpan selected_span, std::vector<float>* output_features) const { + const FeatureProcessorOptions_::BoundsSensitiveFeatures* config = + options_->bounds_sensitive_features(); + + selected_span.first -= extraction_span_.first; + selected_span.second -= extraction_span_.first; + + // Append the features for tokens around the left bound. Masks out tokens + // after the right bound, so that if num_tokens_inside_left goes past it, + // padding tokens will be used. 
+ AppendFeaturesInternal( + /*intended_span=*/{selected_span.first - config->num_tokens_before(), + selected_span.first + + config->num_tokens_inside_left()}, + /*read_mask_span=*/{0, selected_span.second}, output_features); + + // Append the features for tokens around the right bound. Masks out tokens + // before the left bound, so that if num_tokens_inside_right goes past it, + // padding tokens will be used. + AppendFeaturesInternal( + /*intended_span=*/{selected_span.second - + config->num_tokens_inside_right(), + selected_span.second + config->num_tokens_after()}, + /*read_mask_span=*/{selected_span.first, TokenSpanSize(extraction_span_)}, + output_features); + + if (config->include_inside_bag()) { + AppendBagFeatures(selected_span, output_features); + } + + if (config->include_inside_length()) { + output_features->push_back( + static_cast<float>(TokenSpanSize(selected_span))); + } +} + +void CachedFeatures::AppendFeaturesInternal( + const TokenSpan& intended_span, const TokenSpan& read_mask_span, + std::vector<float>* output_features) const { + const TokenSpan copy_span = + IntersectTokenSpans(intended_span, read_mask_span); + for (int i = intended_span.first; i < copy_span.first; ++i) { + AppendPaddingFeatures(output_features); + } + output_features->insert( + output_features->end(), + features_->begin() + copy_span.first * NumFeaturesPerToken(), + features_->begin() + copy_span.second * NumFeaturesPerToken()); + for (int i = copy_span.second; i < intended_span.second; ++i) { + AppendPaddingFeatures(output_features); + } +} + +void CachedFeatures::AppendPaddingFeatures( + std::vector<float>* output_features) const { + output_features->insert(output_features->end(), padding_features_->begin(), + padding_features_->end()); +} + +void CachedFeatures::AppendBagFeatures( + const TokenSpan& bag_span, std::vector<float>* output_features) const { + const int offset = output_features->size(); + output_features->resize(output_features->size() + NumFeaturesPerToken()); 
+ for (int i = bag_span.first; i < bag_span.second; ++i) { + for (int j = 0; j < NumFeaturesPerToken(); ++j) { + (*output_features)[offset + j] += + (*features_)[i * NumFeaturesPerToken() + j] / TokenSpanSize(bag_span); + } + } +} + +int CachedFeatures::NumFeaturesPerToken() const { + return padding_features_->size(); +} + +} // namespace libtextclassifier2 diff --git a/cached-features.h b/cached-features.h new file mode 100644 index 0000000..0224d86 --- /dev/null +++ b/cached-features.h @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ +#define LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ + +#include <memory> +#include <vector> + +#include "model-executor.h" +#include "model_generated.h" +#include "types.h" + +namespace libtextclassifier2 { + +// Holds state for extracting features across multiple calls and reusing them. +// Assumes that features for each Token are independent. +class CachedFeatures { + public: + static std::unique_ptr<CachedFeatures> Create( + const TokenSpan& extraction_span, + std::unique_ptr<std::vector<float>> features, + std::unique_ptr<std::vector<float>> padding_features, + const FeatureProcessorOptions* options, int feature_vector_size); + + // Appends the click context features for the given click position to + // 'output_features'. 
+ void AppendClickContextFeaturesForClick( + int click_pos, std::vector<float>* output_features) const; + + // Appends the bounds-sensitive features for the given token span to + // 'output_features'. + void AppendBoundsSensitiveFeaturesForSpan( + TokenSpan selected_span, std::vector<float>* output_features) const; + + // Returns number of features that 'AppendFeaturesForSpan' appends. + int OutputFeaturesSize() const { return output_features_size_; } + + private: + CachedFeatures() {} + + // Appends token features to the output. The intended_span specifies which + // tokens' features should be used in principle. The read_mask_span restricts + // which tokens are actually read. For tokens outside of the read_mask_span, + // padding tokens are used instead. + void AppendFeaturesInternal(const TokenSpan& intended_span, + const TokenSpan& read_mask_span, + std::vector<float>* output_features) const; + + // Appends features of one padding token to the output. + void AppendPaddingFeatures(std::vector<float>* output_features) const; + + // Appends the features of tokens from the given span to the output. The + // features are averaged so that the appended features have the size + // corresponding to one token. 
+ void AppendBagFeatures(const TokenSpan& bag_span, + std::vector<float>* output_features) const; + + int NumFeaturesPerToken() const; + + TokenSpan extraction_span_; + const FeatureProcessorOptions* options_; + int output_features_size_; + std::unique_ptr<std::vector<float>> features_; + std::unique_ptr<std::vector<float>> padding_features_; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ diff --git a/cached-features_test.cc b/cached-features_test.cc new file mode 100644 index 0000000..f064a63 --- /dev/null +++ b/cached-features_test.cc @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cached-features.h" + +#include "model-executor.h" +#include "tensor-view.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::ElementsAreArray; +using testing::FloatEq; +using testing::Matcher; + +namespace libtextclassifier2 { +namespace { + +Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) { + std::vector<Matcher<float>> matchers; + for (const float value : values) { + matchers.push_back(FloatEq(value)); + } + return ElementsAreArray(matchers); +} + +std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) { + std::unique_ptr<std::vector<float>> features(new std::vector<float>()); + for (int i = 1; i <= num_tokens; ++i) { + features->push_back(i * 11.0f); + features->push_back(-i * 11.0f); + features->push_back(i * 0.1f); + } + return features; +} + +std::vector<float> GetCachedClickContextFeatures( + const CachedFeatures& cached_features, int click_pos) { + std::vector<float> output_features; + cached_features.AppendClickContextFeaturesForClick(click_pos, + &output_features); + return output_features; +} + +std::vector<float> GetCachedBoundsSensitiveFeatures( + const CachedFeatures& cached_features, TokenSpan selected_span) { + std::vector<float> output_features; + cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span, + &output_features); + return output_features; +} + +TEST(CachedFeaturesTest, ClickContext) { + FeatureProcessorOptionsT options; + options.context_size = 2; + options.feature_version = 1; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(CreateFeatureProcessorOptions(builder, &options)); + flatbuffers::DetachedBuffer options_fb = builder.Release(); + + std::unique_ptr<std::vector<float>> features = MakeFeatures(9); + std::unique_ptr<std::vector<float>> padding_features( + new std::vector<float>{112233.0, -112233.0, 321.0}); + + const std::unique_ptr<CachedFeatures> cached_features = + CachedFeatures::Create( + {3, 10}, std::move(features), 
std::move(padding_features), + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + /*feature_vector_size=*/3); + ASSERT_TRUE(cached_features); + + EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5), + ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0, + 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5})); + + EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6), + ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0, + 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6})); + + EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7), + ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, + 0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7})); +} + +TEST(CachedFeaturesTest, BoundsSensitive) { + std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config( + new FeatureProcessorOptions_::BoundsSensitiveFeaturesT()); + config->enabled = true; + config->num_tokens_before = 2; + config->num_tokens_inside_left = 2; + config->num_tokens_inside_right = 2; + config->num_tokens_after = 2; + config->include_inside_bag = true; + config->include_inside_length = true; + FeatureProcessorOptionsT options; + options.bounds_sensitive_features = std::move(config); + options.feature_version = 2; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(CreateFeatureProcessorOptions(builder, &options)); + flatbuffers::DetachedBuffer options_fb = builder.Release(); + + std::unique_ptr<std::vector<float>> features = MakeFeatures(9); + std::unique_ptr<std::vector<float>> padding_features( + new std::vector<float>{112233.0, -112233.0, 321.0}); + + const std::unique_ptr<CachedFeatures> cached_features = + CachedFeatures::Create( + {3, 9}, std::move(features), std::move(padding_features), + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + /*feature_vector_size=*/3); + ASSERT_TRUE(cached_features); + + EXPECT_THAT( + GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}), + ElementsAreFloat({11.0, 
-11.0, 0.1, 22.0, -22.0, 0.2, 33.0, + -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0, + 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6, + 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0})); + + EXPECT_THAT( + GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}), + ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, + -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0, + 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5, + 66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0})); + + EXPECT_THAT( + GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}), + ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, + -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0, + 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6, + 112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0})); + + EXPECT_THAT( + GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}), + ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, + 44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0, + 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, + 55.0, -55.0, 0.5, 66.0, -66.0, 0.6, + 44.0, -44.0, 0.4, 1.0})); +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/common/algorithm.h b/common/algorithm.h deleted file mode 100644 index 365eec9..0000000 --- a/common/algorithm.h +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Generic utils similar to those from the C++ header <algorithm>. 
- -#ifndef LIBTEXTCLASSIFIER_COMMON_ALGORITHM_H_ -#define LIBTEXTCLASSIFIER_COMMON_ALGORITHM_H_ - -#include <algorithm> -#include <vector> - -namespace libtextclassifier { -namespace nlp_core { - -// Returns index of max element from the vector |elements|. Returns 0 if -// |elements| is empty. T should be a type that can be compared by operator<. -template<typename T> -inline int GetArgMax(const std::vector<T> &elements) { - return std::distance( - elements.begin(), - std::max_element(elements.begin(), elements.end())); -} - -// Returns index of min element from the vector |elements|. Returns 0 if -// |elements| is empty. T should be a type that can be compared by operator<. -template<typename T> -inline int GetArgMin(const std::vector<T> &elements) { - return std::distance( - elements.begin(), - std::min_element(elements.begin(), elements.end())); -} - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_ALGORITHM_H_ diff --git a/common/embedding-feature-extractor.cc b/common/embedding-feature-extractor.cc deleted file mode 100644 index 254af45..0000000 --- a/common/embedding-feature-extractor.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "common/embedding-feature-extractor.h" - -#include <stddef.h> - -#include <vector> - -#include "common/feature-extractor.h" -#include "common/feature-types.h" -#include "common/task-context.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/strings/numbers.h" -#include "util/strings/split.h" - -namespace libtextclassifier { -namespace nlp_core { - -bool GenericEmbeddingFeatureExtractor::Init(TaskContext *context) { - // Don't use version to determine how to get feature FML. - const std::string features = context->Get(GetParamName("features"), ""); - TC_LOG(INFO) << "Features: " << features; - - const std::string embedding_names = - context->Get(GetParamName("embedding_names"), ""); - TC_LOG(INFO) << "Embedding names: " << embedding_names; - - const std::string embedding_dims = - context->Get(GetParamName("embedding_dims"), ""); - TC_LOG(INFO) << "Embedding dims: " << embedding_dims; - - embedding_fml_ = strings::Split(features, ';'); - embedding_names_ = strings::Split(embedding_names, ';'); - for (const std::string &dim : strings::Split(embedding_dims, ';')) { - int32 parsed_dim = 0; - if (!ParseInt32(dim.c_str(), &parsed_dim)) { - TC_LOG(ERROR) << "Unable to parse dim " << dim; - return false; - } - embedding_dims_.push_back(parsed_dim); - } - if ((embedding_fml_.size() != embedding_names_.size()) || - (embedding_fml_.size() != embedding_dims_.size())) { - TC_LOG(ERROR) << "Mismatch: #fml specs = " << embedding_fml_.size() - << "; #names = " << embedding_names_.size() - << "; #dims = " << embedding_dims_.size(); - return false; - } - return true; -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/embedding-feature-extractor.h b/common/embedding-feature-extractor.h deleted file mode 100644 index 0efd0d2..0000000 --- a/common/embedding-feature-extractor.h +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache 
License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ -#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "common/feature-extractor.h" -#include "common/task-context.h" -#include "common/workspace.h" -#include "util/base/logging.h" -#include "util/base/macros.h" - -namespace libtextclassifier { -namespace nlp_core { - -// An EmbeddingFeatureExtractor manages the extraction of features for -// embedding-based models. It wraps a sequence of underlying classes of feature -// extractors, along with associated predicate maps. Each class of feature -// extractors is associated with a name, e.g., "words", "labels", "tags". -// -// The class is split between a generic abstract version, -// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the -// signature of the ExtractFeatures method) and a typed version. -// -// The predicate maps must be initialized before use: they can be loaded using -// Read() or updated via UpdateMapsForExample. -class GenericEmbeddingFeatureExtractor { - public: - GenericEmbeddingFeatureExtractor() {} - virtual ~GenericEmbeddingFeatureExtractor() {} - - // Get the prefix std::string to put in front of all arguments, so they don't - // conflict with other embedding models. 
- virtual const std::string ArgPrefix() const = 0; - - // Initializes predicate maps and embedding space names that are common for - // all embedding-based feature extractors. - virtual bool Init(TaskContext *context); - - // Requests workspace for the underlying feature extractors. This is - // implemented in the typed class. - virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0; - - // Returns number of embedding spaces. - int NumEmbeddings() const { return embedding_dims_.size(); } - - // Number of predicates for the embedding at a given index (vocabulary size). - // Returns -1 if index is out of bounds. - int EmbeddingSize(int index) const { - const GenericFeatureExtractor *extractor = generic_feature_extractor(index); - return (extractor == nullptr) ? -1 : extractor->GetDomainSize(); - } - - // Returns the dimensionality of the embedding space. - int EmbeddingDims(int index) const { return embedding_dims_[index]; } - - // Accessor for embedding dims (dimensions of the embedding spaces). - const std::vector<int> &embedding_dims() const { return embedding_dims_; } - - const std::vector<std::string> &embedding_fml() const { - return embedding_fml_; - } - - // Get parameter name by concatenating the prefix and the original name. - std::string GetParamName(const std::string ¶m_name) const { - std::string full_name = ArgPrefix(); - full_name.push_back('_'); - full_name.append(param_name); - return full_name; - } - - protected: - // Provides the generic class with access to the templated extractors. This is - // used to get the type information out of the feature extractor without - // knowing the specific calling arguments of the extractor itself. - // Returns nullptr for an out-of-bounds idx. - virtual const GenericFeatureExtractor *generic_feature_extractor( - int idx) const = 0; - - private: - // Embedding space names for parameter sharing. - std::vector<std::string> embedding_names_; - - // FML strings for each feature extractor. 
- std::vector<std::string> embedding_fml_; - - // Size of each of the embedding spaces (maximum predicate id). - std::vector<int> embedding_sizes_; - - // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.) - std::vector<int> embedding_dims_; - - TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor); -}; - -// Templated, object-specific implementation of the -// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ, -// ARGS...> class that has the appropriate FeatureTraits() to ensure that -// locator type features work. -// -// Note: for backwards compatibility purposes, this always reads the FML spec -// from "<prefix>_features". -template <class EXTRACTOR, class OBJ, class... ARGS> -class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { - public: - // Initializes all predicate maps, feature extractors, etc. - bool Init(TaskContext *context) override { - if (!GenericEmbeddingFeatureExtractor::Init(context)) { - return false; - } - feature_extractors_.resize(embedding_fml().size()); - for (int i = 0; i < embedding_fml().size(); ++i) { - feature_extractors_[i].reset(new EXTRACTOR()); - if (!feature_extractors_[i]->Parse(embedding_fml()[i])) { - return false; - } - if (!feature_extractors_[i]->Setup(context)) { - return false; - } - } - for (auto &feature_extractor : feature_extractors_) { - if (!feature_extractor->Init(context)) { - return false; - } - } - return true; - } - - // Requests workspaces from the registry. Must be called after Init(), and - // before Preprocess(). - void RequestWorkspaces(WorkspaceRegistry *registry) override { - for (auto &feature_extractor : feature_extractors_) { - feature_extractor->RequestWorkspaces(registry); - } - } - - // Must be called on the object one state for each sentence, before any - // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures). 
- void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const { - for (auto &feature_extractor : feature_extractors_) { - feature_extractor->Preprocess(workspaces, obj); - } - } - - // Extracts features using the extractors. Note that features must already - // be initialized to the correct number of feature extractors. No predicate - // mapping is applied. - void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj, - ARGS... args, - std::vector<FeatureVector> *features) const { - TC_DCHECK(features != nullptr); - TC_DCHECK_EQ(features->size(), feature_extractors_.size()); - for (int i = 0; i < feature_extractors_.size(); ++i) { - (*features)[i].clear(); - feature_extractors_[i]->ExtractFeatures(workspaces, obj, args..., - &(*features)[i]); - } - } - - protected: - // Provides generic access to the feature extractors. - const GenericFeatureExtractor *generic_feature_extractor( - int idx) const override { - if ((idx < 0) || (idx >= feature_extractors_.size())) { - TC_LOG(ERROR) << "Out of bounds index " << idx; - TC_DCHECK(false); // Crash in debug mode. - return nullptr; - } - return feature_extractors_[idx].get(); - } - - private: - // Templated feature extractor class. - std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ diff --git a/common/embedding-feature-extractor_test.cc b/common/embedding-feature-extractor_test.cc deleted file mode 100644 index c5ed627..0000000 --- a/common/embedding-feature-extractor_test.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/embedding-feature-extractor.h" - -#include "lang_id/language-identifier-features.h" -#include "lang_id/light-sentence-features.h" -#include "lang_id/light-sentence.h" -#include "lang_id/relevant-script-feature.h" -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace nlp_core { - -class EmbeddingFeatureExtractorTest : public ::testing::Test { - public: - void SetUp() override { - // Make sure all relevant features are registered: - lang_id::ContinuousBagOfNgramsFunction::RegisterClass(); - lang_id::RelevantScriptFeature::RegisterClass(); - } -}; - -// Specialization of EmbeddingFeatureExtractor that extracts from LightSentence. 
-class TestEmbeddingFeatureExtractor - : public EmbeddingFeatureExtractor<lang_id::LightSentenceExtractor, - lang_id::LightSentence> { - public: - const std::string ArgPrefix() const override { return "test"; } -}; - -TEST_F(EmbeddingFeatureExtractorTest, NoEmbeddingSpaces) { - TaskContext context; - context.SetParameter("test_features", ""); - context.SetParameter("test_embedding_names", ""); - context.SetParameter("test_embedding_dims", ""); - TestEmbeddingFeatureExtractor tefe; - ASSERT_TRUE(tefe.Init(&context)); - EXPECT_EQ(tefe.NumEmbeddings(), 0); -} - -TEST_F(EmbeddingFeatureExtractorTest, GoodSpec) { - TaskContext context; - const std::string spec = - "continuous-bag-of-ngrams(id_dim=5000,size=3);" - "continuous-bag-of-ngrams(id_dim=7000,size=4)"; - context.SetParameter("test_features", spec); - context.SetParameter("test_embedding_names", "trigram;quadgram"); - context.SetParameter("test_embedding_dims", "16;24"); - TestEmbeddingFeatureExtractor tefe; - ASSERT_TRUE(tefe.Init(&context)); - EXPECT_EQ(tefe.NumEmbeddings(), 2); - EXPECT_EQ(tefe.EmbeddingSize(0), 5000); - EXPECT_EQ(tefe.EmbeddingDims(0), 16); - EXPECT_EQ(tefe.EmbeddingSize(1), 7000); - EXPECT_EQ(tefe.EmbeddingDims(1), 24); -} - -TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsNames) { - TaskContext context; - const std::string spec = - "continuous-bag-of-ngrams(id_dim=5000,size=3);" - "continuous-bag-of-ngrams(id_dim=7000,size=4)"; - context.SetParameter("test_features", spec); - context.SetParameter("test_embedding_names", "trigram"); - context.SetParameter("test_embedding_dims", "16;16"); - TestEmbeddingFeatureExtractor tefe; - ASSERT_FALSE(tefe.Init(&context)); -} - -TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsDims) { - TaskContext context; - const std::string spec = - "continuous-bag-of-ngrams(id_dim=5000,size=3);" - "continuous-bag-of-ngrams(id_dim=7000,size=4)"; - context.SetParameter("test_features", spec); - context.SetParameter("test_embedding_names", "trigram;quadgram"); - 
context.SetParameter("test_embedding_dims", "16;16;32"); - TestEmbeddingFeatureExtractor tefe; - ASSERT_FALSE(tefe.Init(&context)); -} - -TEST_F(EmbeddingFeatureExtractorTest, BrokenSpec) { - TaskContext context; - const std::string spec = - "continuous-bag-of-ngrams(id_dim=5000;" - "continuous-bag-of-ngrams(id_dim=7000,size=4)"; - context.SetParameter("test_features", spec); - context.SetParameter("test_embedding_names", "trigram;quadgram"); - context.SetParameter("test_embedding_dims", "16;16"); - TestEmbeddingFeatureExtractor tefe; - ASSERT_FALSE(tefe.Init(&context)); -} - -TEST_F(EmbeddingFeatureExtractorTest, MissingFeature) { - TaskContext context; - const std::string spec = - "continuous-bag-of-ngrams(id_dim=5000,size=3);" - "no-such-feature"; - context.SetParameter("test_features", spec); - context.SetParameter("test_embedding_names", "trigram;foo"); - context.SetParameter("test_embedding_dims", "16;16"); - TestEmbeddingFeatureExtractor tefe; - ASSERT_FALSE(tefe.Init(&context)); -} - -TEST_F(EmbeddingFeatureExtractorTest, MultipleFeatures) { - TaskContext context; - const std::string spec = - "continuous-bag-of-ngrams(id_dim=1000,size=3);" - "continuous-bag-of-relevant-scripts"; - context.SetParameter("test_features", spec); - context.SetParameter("test_embedding_names", "trigram;script"); - context.SetParameter("test_embedding_dims", "8;16"); - TestEmbeddingFeatureExtractor tefe; - ASSERT_TRUE(tefe.Init(&context)); - EXPECT_EQ(tefe.NumEmbeddings(), 2); - EXPECT_EQ(tefe.EmbeddingSize(0), 1000); - EXPECT_EQ(tefe.EmbeddingDims(0), 8); - - // continuous-bag-of-relevant-scripts has its own hard-wired vocabulary size. - // We don't want this test to depend on that value; we just check it's bigger - // than 0. 
- EXPECT_GT(tefe.EmbeddingSize(1), 0); - EXPECT_EQ(tefe.EmbeddingDims(1), 16); -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/embedding-network-package.proto b/common/embedding-network-package.proto deleted file mode 100644 index 54d47e6..0000000 --- a/common/embedding-network-package.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file defines TaskSpec as an extension to EmbeddingNetworkProto. The -// definition is done here rather than directly in the imported protos to keep -// the different messages as independent as possible. - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -import "external/libtextclassifier/common/task-spec.proto"; -import "external/libtextclassifier/common/embedding-network.proto"; - -package libtextclassifier.nlp_core; - -extend EmbeddingNetworkProto { - optional TaskSpec task_spec_in_embedding_network_proto = 129692954; -} diff --git a/common/embedding-network-params-from-proto.h b/common/embedding-network-params-from-proto.h deleted file mode 100644 index 2f2c429..0000000 --- a/common/embedding-network-params-from-proto.h +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_ -#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_ - -#include <algorithm> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "common/embedding-network-package.pb.h" -#include "common/embedding-network-params.h" -#include "common/embedding-network.pb.h" -#include "common/float16.h" -#include "common/little-endian-data.h" -#include "common/task-context.h" -#include "common/task-spec.pb.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -// A wrapper class that owns and exposes an EmbeddingNetworkProto message via -// the EmbeddingNetworkParams interface. -// -// The EmbeddingNetworkParams interface encapsulates the weight matrices of the -// embeddings, hidden and softmax layers as transposed versions of their -// counterparts in the original EmbeddingNetworkProto. The matrices in the proto -// passed to this class' constructor must likewise already have been transposed. -// See embedding-network-params.h for details. -class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams { - public: - // Constructor that takes ownership of the provided proto. See class-comment - // for the requirements that certain weight matrices must satisfy. 
- explicit EmbeddingNetworkParamsFromProto( - std::unique_ptr<EmbeddingNetworkProto> proto) - : proto_(std::move(proto)) { - valid_ = true; - - // Initialize these vectors to have the required number of elements - // regardless of quantization status. This is to support the unlikely case - // where only some embeddings are quantized, along with the fact that - // EmbeddingNetworkParams interface accesses them by index. - embeddings_quant_scales_.resize(proto_->embeddings_size()); - embeddings_quant_weights_.resize(proto_->embeddings_size()); - for (int i = 0; i < proto_->embeddings_size(); ++i) { - MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i); - if (!embedding->is_quantized()) { - continue; - } - - bool success = FillVectorFromDataBytesInLittleEndian( - embedding->bytes_for_quantized_values(), - embedding->rows() * embedding->cols(), - &(embeddings_quant_weights_[i])); - if (!success) { - TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #" << i; - valid_ = false; - } - - // The repeated field bytes_for_quantized_values uses a lot of memory. - // Since it's no longer necessary (and we own the proto), we clear it. - embedding->clear_bytes_for_quantized_values(); - - success = FillVectorFromDataBytesInLittleEndian( - embedding->bytes_for_col_scales(), - embedding->rows(), - &(embeddings_quant_scales_[i])); - if (!success) { - TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i; - valid_ = false; - } - - // See comments for clear_bytes_for_quantized_values(). - embedding->clear_bytes_for_col_scales(); - } - } - - const TaskSpec *GetTaskSpec() override { - if (!proto_) { - return nullptr; - } - auto extension_id = task_spec_in_embedding_network_proto; - if (proto_->HasExtension(extension_id)) { - return &(proto_->GetExtension(extension_id)); - } else { - TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto"; - return nullptr; - } - } - - // Returns true if these params are valid. 
False otherwise (e.g., if the - // original proto data was corrupted). - bool is_valid() { return valid_; } - - protected: - int embeddings_size() const override { return proto_->embeddings_size(); } - - int embeddings_num_rows(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - return proto_->embeddings(i).rows(); - } - - int embeddings_num_cols(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - return proto_->embeddings(i).cols(); - } - - const void *embeddings_weights(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - if (proto_->embeddings(i).is_quantized()) { - return static_cast<const void *>(embeddings_quant_weights_.at(i).data()); - } else { - return static_cast<const void *>(proto_->embeddings(i).value().data()); - } - } - - QuantizationType embeddings_quant_type(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8 - : QuantizationType::NONE; - } - - const float16 *embeddings_quant_scales(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - return proto_->embeddings(i).is_quantized() - ? 
embeddings_quant_scales_.at(i).data() - : nullptr; - } - - int hidden_size() const override { return proto_->hidden_size(); } - - int hidden_num_rows(int i) const override { - TC_DCHECK(InRange(i, hidden_size())); - return proto_->hidden(i).rows(); - } - - int hidden_num_cols(int i) const override { - TC_DCHECK(InRange(i, hidden_size())); - return proto_->hidden(i).cols(); - } - - const void *hidden_weights(int i) const override { - TC_DCHECK(InRange(i, hidden_size())); - return proto_->hidden(i).value().data(); - } - - int hidden_bias_size() const override { return proto_->hidden_bias_size(); } - - int hidden_bias_num_rows(int i) const override { - TC_DCHECK(InRange(i, hidden_bias_size())); - return proto_->hidden_bias(i).rows(); - } - - int hidden_bias_num_cols(int i) const override { - TC_DCHECK(InRange(i, hidden_bias_size())); - return proto_->hidden_bias(i).cols(); - } - - const void *hidden_bias_weights(int i) const override { - TC_DCHECK(InRange(i, hidden_bias_size())); - return proto_->hidden_bias(i).value().data(); - } - - int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; } - - int softmax_num_rows(int i) const override { - TC_DCHECK(InRange(i, softmax_size())); - return proto_->has_softmax() ? proto_->softmax().rows() : 0; - } - - int softmax_num_cols(int i) const override { - TC_DCHECK(InRange(i, softmax_size())); - return proto_->has_softmax() ? proto_->softmax().cols() : 0; - } - - const void *softmax_weights(int i) const override { - TC_DCHECK(InRange(i, softmax_size())); - return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr; - } - - int softmax_bias_size() const override { - return proto_->has_softmax_bias() ? 1 : 0; - } - - int softmax_bias_num_rows(int i) const override { - TC_DCHECK(InRange(i, softmax_bias_size())); - return proto_->has_softmax_bias() ? 
proto_->softmax_bias().rows() : 0; - } - - int softmax_bias_num_cols(int i) const override { - TC_DCHECK(InRange(i, softmax_bias_size())); - return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0; - } - - const void *softmax_bias_weights(int i) const override { - TC_DCHECK(InRange(i, softmax_bias_size())); - return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data() - : nullptr; - } - - int embedding_num_features_size() const override { - return proto_->embedding_num_features_size(); - } - - int embedding_num_features(int i) const override { - TC_DCHECK(InRange(i, embedding_num_features_size())); - return proto_->embedding_num_features(i); - } - - private: - std::unique_ptr<EmbeddingNetworkProto> proto_; - - // True if these params are valid. May be false if the original proto was - // corrupted. We prefer to set this to false to CHECK-failing. - bool valid_; - - // When the embeddings are quantized, these members are used to store their - // numeric values using the types expected by the rest of the class. Due to - // technical reasons, the proto stores this info using larger types (i.e., - // more bits). - std::vector<std::vector<float16>> embeddings_quant_scales_; - std::vector<std::vector<uint8>> embeddings_quant_weights_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_ diff --git a/common/embedding-network-params.h b/common/embedding-network-params.h deleted file mode 100755 index ee2d9dc..0000000 --- a/common/embedding-network-params.h +++ /dev/null @@ -1,325 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_ -#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_ - -#include <algorithm> -#include <string> - -#include "common/float16.h" -#include "common/task-context.h" -#include "common/task-spec.pb.h" -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -enum class QuantizationType { NONE = 0, UINT8 }; - -// API for accessing parameters for a feed-forward neural network with -// embeddings. -// -// Note: this API is closely related to embedding-network.proto. The reason we -// have a separate API is that the proto may not be the only way of packaging -// these parameters. -class EmbeddingNetworkParams { - public: - virtual ~EmbeddingNetworkParams() {} - - // **** High-level API. - - // Simple representation of a matrix. This small struct that doesn't own any - // resource intentionally supports copy / assign, to simplify our APIs. - struct Matrix { - // Number of rows. - int rows; - - // Number of columns. - int cols; - - QuantizationType quant_type; - - // Pointer to matrix elements, in row-major order - // (https://en.wikipedia.org/wiki/Row-major_order) Not owned. - const void *elements; - - // Quantization scales: one scale for each row. - const float16 *quant_scales; - }; - - // Returns number of embedding spaces. 
- int GetNumEmbeddingSpaces() const { - if (embeddings_size() != embedding_num_features_size()) { - TC_LOG(ERROR) << "Embedding spaces mismatch " << embeddings_size() - << " != " << embedding_num_features_size(); - } - return std::max(0, - std::min(embeddings_size(), embedding_num_features_size())); - } - - // Returns embedding matrix for the i-th embedding space. - // - // NOTE: i must be in [0, GetNumEmbeddingSpaces()). Undefined behavior - // otherwise. - Matrix GetEmbeddingMatrix(int i) const { - TC_DCHECK(InRange(i, embeddings_size())); - Matrix matrix; - matrix.rows = embeddings_num_rows(i); - matrix.cols = embeddings_num_cols(i); - matrix.elements = embeddings_weights(i); - matrix.quant_type = embeddings_quant_type(i); - matrix.quant_scales = embeddings_quant_scales(i); - return matrix; - } - - // Returns number of features in i-th embedding space. - // - // NOTE: i must be in [0, GetNumEmbeddingSpaces()). Undefined behavior - // otherwise. - int GetNumFeaturesInEmbeddingSpace(int i) const { - TC_DCHECK(InRange(i, embedding_num_features_size())); - return std::max(0, embedding_num_features(i)); - } - - // Returns number of hidden layers in the neural network. Each such layer has - // weight matrix and a bias vector (a matrix with one column). - int GetNumHiddenLayers() const { - if (hidden_size() != hidden_bias_size()) { - TC_LOG(ERROR) << "Hidden layer mismatch " << hidden_size() - << " != " << hidden_bias_size(); - } - return std::max(0, std::min(hidden_size(), hidden_bias_size())); - } - - // Returns weight matrix for i-th hidden layer. - // - // NOTE: i must be in [0, GetNumHiddenLayers()). Undefined behavior - // otherwise. - Matrix GetHiddenLayerMatrix(int i) const { - TC_DCHECK(InRange(i, hidden_size())); - Matrix matrix; - matrix.rows = hidden_num_rows(i); - matrix.cols = hidden_num_cols(i); - - // Quantization not supported here. 
- matrix.quant_type = QuantizationType::NONE; - matrix.elements = hidden_weights(i); - return matrix; - } - - // Returns bias matrix for i-th hidden layer. Technically a Matrix, but we - // expect it to be a vector (i.e., num cols is 1). - // - // NOTE: i must be in [0, GetNumHiddenLayers()). Undefined behavior - // otherwise. - Matrix GetHiddenLayerBias(int i) const { - TC_DCHECK(InRange(i, hidden_bias_size())); - Matrix matrix; - matrix.rows = hidden_bias_num_rows(i); - matrix.cols = hidden_bias_num_cols(i); - - // Quantization not supported here. - matrix.quant_type = QuantizationType::NONE; - matrix.elements = hidden_bias_weights(i); - return matrix; - } - - // Returns true if a softmax layer exists. - bool HasSoftmaxLayer() const { - if (softmax_size() != softmax_bias_size()) { - TC_LOG(ERROR) << "Softmax layer mismatch " << softmax_size() - << " != " << softmax_bias_size(); - } - return (softmax_size() == 1) && (softmax_bias_size() == 1); - } - - // Returns weight matrix for the softmax layer. - // - // NOTE: Should be called only if HasSoftmaxLayer() is true. Undefined - // behavior otherwise. - Matrix GetSoftmaxMatrix() const { - TC_DCHECK(softmax_size() == 1); - Matrix matrix; - matrix.rows = softmax_num_rows(0); - matrix.cols = softmax_num_cols(0); - - // Quantization not supported here. - matrix.quant_type = QuantizationType::NONE; - matrix.elements = softmax_weights(0); - return matrix; - } - - // Returns bias for the softmax layer. Technically a Matrix, but we expect it - // to be a row/column vector (i.e., num cols is 1). - // - // NOTE: Should be called only if HasSoftmaxLayer() is true. Undefined - // behavior otherwise. - Matrix GetSoftmaxBias() const { - TC_DCHECK(softmax_bias_size() == 1); - Matrix matrix; - matrix.rows = softmax_bias_num_rows(0); - matrix.cols = softmax_bias_num_cols(0); - - // Quantization not supported here. 
- matrix.quant_type = QuantizationType::NONE; - matrix.elements = softmax_bias_weights(0); - return matrix; - } - - // Updates the EmbeddingNetwork-related parameters from task_context. Returns - // true on success, false on error. - virtual bool UpdateTaskContextParameters(TaskContext *task_context) { - const TaskSpec *task_spec = GetTaskSpec(); - if (task_spec == nullptr) { - TC_LOG(ERROR) << "Unable to get TaskSpec"; - return false; - } - for (const TaskSpec::Parameter ¶meter : task_spec->parameter()) { - task_context->SetParameter(parameter.name(), parameter.value()); - } - return true; - } - - // Returns a pointer to a TaskSpec with the EmbeddingNetwork-related - // parameters. Returns nullptr in case of problems. Ownership with the - // returned pointer is *not* transfered to the caller. - virtual const TaskSpec *GetTaskSpec() { - TC_LOG(ERROR) << "Not implemented"; - return nullptr; - } - - protected: - // **** Low-level API. - // - // * Most low-level API methods are documented by giving an equivalent - // function call on proto, the original proto (of type - // EmbeddingNetworkProto) which was used to generate the C++ code. - // - // * To simplify our generation code, optional proto fields of message type - // are treated as repeated fields with 0 or 1 instances. As such, we have - // *_size() methods for such optional fields: they return 0 or 1. - // - // * "transpose(M)" denotes the transpose of a matrix M. - // - // * Behavior is undefined when trying to retrieve a piece of data that does - // not exist: e.g., embeddings_num_rows(5) if embeddings_size() == 2. - - // ** Access methods for repeated MatrixParams embeddings. - // - // Returns proto.embeddings_size(). - virtual int embeddings_size() const = 0; - - // Returns number of rows of transpose(proto.embeddings(i)). - virtual int embeddings_num_rows(int i) const = 0; - - // Returns number of columns of transpose(proto.embeddings(i)). 
- virtual int embeddings_num_cols(int i) const = 0; - - // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major - // order. NOTE: for unquantized embeddings, this returns a pointer to float; - // for quantized embeddings, this returns a pointer to uint8. - virtual const void *embeddings_weights(int i) const = 0; - - virtual QuantizationType embeddings_quant_type(int i) const { - return QuantizationType::NONE; - } - - virtual const float16 *embeddings_quant_scales(int i) const { - return nullptr; - } - - // ** Access methods for repeated MatrixParams hidden. - // - // Returns embedding_network_proto.hidden_size(). - virtual int hidden_size() const = 0; - - // Returns embedding_network_proto.hidden(i).rows(). - virtual int hidden_num_rows(int i) const = 0; - - // Returns embedding_network_proto.hidden(i).rows(). - virtual int hidden_num_cols(int i) const = 0; - - // Returns pointer to beginning of array of floats with all values from - // embedding_network_proto.hidden(i). - virtual const void *hidden_weights(int i) const = 0; - - // ** Access methods for repeated MatrixParams hidden_bias. - // - // Returns proto.hidden_bias_size(). - virtual int hidden_bias_size() const = 0; - - // Returns number of rows of proto.hidden_bias(i). - virtual int hidden_bias_num_rows(int i) const = 0; - - // Returns number of columns of proto.hidden_bias(i). - virtual int hidden_bias_num_cols(int i) const = 0; - - // Returns pointer to elements of proto.hidden_bias(i), in row-major order. - virtual const void *hidden_bias_weights(int i) const = 0; - - // ** Access methods for optional MatrixParams softmax. - // - // Returns 1 if proto has optional field softmax, 0 otherwise. - virtual int softmax_size() const = 0; - - // Returns number of rows of transpose(proto.softmax()). - virtual int softmax_num_rows(int i) const = 0; - - // Returns number of columns of transpose(proto.softmax()). 
- virtual int softmax_num_cols(int i) const = 0; - - // Returns pointer to elements of transpose(proto.softmax()), in row-major - // order. - virtual const void *softmax_weights(int i) const = 0; - - // ** Access methods for optional MatrixParams softmax_bias. - // - // Returns 1 if proto has optional field softmax_bias, 0 otherwise. - virtual int softmax_bias_size() const = 0; - - // Returns number of rows of proto.softmax_bias(). - virtual int softmax_bias_num_rows(int i) const = 0; - - // Returns number of columns of proto.softmax_bias(). - virtual int softmax_bias_num_cols(int i) const = 0; - - // Returns pointer to elements of proto.softmax_bias(), in row-major order. - virtual const void *softmax_bias_weights(int i) const = 0; - - // ** Access methods for repeated int32 embedding_num_features. - // - // Returns proto.embedding_num_features_size(). - virtual int embedding_num_features_size() const = 0; - - // Returns proto.embedding_num_features(i). - virtual int embedding_num_features(int i) const = 0; - - // Returns true if and only if index is in range [0, size). Log an error - // message otherwise. - static bool InRange(int index, int size) { - if ((index < 0) || (index >= size)) { - TC_LOG(ERROR) << "Index " << index << " outside [0, " << size << ")"; - return false; - } - return true; - } -}; // class EmbeddingNetworkParams - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_ diff --git a/common/embedding-network.cc b/common/embedding-network.cc deleted file mode 100644 index b27cda3..0000000 --- a/common/embedding-network.cc +++ /dev/null @@ -1,380 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/embedding-network.h" - -#include <math.h> - -#include "common/simple-adder.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -namespace { - -// Returns true if and only if matrix does not use any quantization. -bool CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) { - if (matrix.quant_type != QuantizationType::NONE) { - TC_LOG(ERROR) << "Unsupported quantization"; - TC_DCHECK(false); // Crash in debug mode. - return false; - } - return true; -} - -// Initializes a Matrix object with the parameters from the MatrixParams -// source_matrix. source_matrix should not use quantization. -// -// Returns true on success, false on error. -bool InitNonQuantizedMatrix(const EmbeddingNetworkParams::Matrix &source_matrix, - EmbeddingNetwork::Matrix *mat) { - mat->resize(source_matrix.rows); - - // Before we access the weights as floats, we need to check that they are - // really floats, i.e., no quantization is used. - if (!CheckNoQuantization(source_matrix)) return false; - const float *weights = - reinterpret_cast<const float *>(source_matrix.elements); - for (int r = 0; r < source_matrix.rows; ++r) { - (*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols); - weights += source_matrix.cols; - } - return true; -} - -// Initializes a VectorWrapper object with the parameters from the MatrixParams -// source_matrix. source_matrix should have exactly one column and should not -// use quantization. 
-// -// Returns true on success, false on error. -bool InitNonQuantizedVector(const EmbeddingNetworkParams::Matrix &source_matrix, - EmbeddingNetwork::VectorWrapper *vector) { - if (source_matrix.cols != 1) { - TC_LOG(ERROR) << "wrong #cols " << source_matrix.cols; - return false; - } - if (!CheckNoQuantization(source_matrix)) { - TC_LOG(ERROR) << "unsupported quantization"; - return false; - } - // Before we access the weights as floats, we need to check that they are - // really floats, i.e., no quantization is used. - if (!CheckNoQuantization(source_matrix)) return false; - const float *weights = - reinterpret_cast<const float *>(source_matrix.elements); - *vector = EmbeddingNetwork::VectorWrapper(weights, source_matrix.rows); - return true; -} - -// Computes y = weights * Relu(x) + b where Relu is optionally applied. -template <typename ScaleAdderClass> -bool SparseReluProductPlusBias(bool apply_relu, - const EmbeddingNetwork::Matrix &weights, - const EmbeddingNetwork::VectorWrapper &b, - const VectorSpan<float> &x, - EmbeddingNetwork::Vector *y) { - // Check that dimensions match. - if ((x.size() != weights.size()) || weights.empty()) { - TC_LOG(ERROR) << x.size() << " != " << weights.size(); - return false; - } - if (weights[0].size() != b.size()) { - TC_LOG(ERROR) << weights[0].size() << " != " << b.size(); - return false; - } - - y->assign(b.data(), b.data() + b.size()); - ScaleAdderClass adder(y->data(), y->size()); - - const int x_size = x.size(); - for (int i = 0; i < x_size; ++i) { - const float &scale = x[i]; - if (apply_relu) { - if (scale > 0) { - adder.LazyScaleAdd(weights[i].data(), scale); - } - } else { - adder.LazyScaleAdd(weights[i].data(), scale); - } - } - return true; -} -} // namespace - -bool EmbeddingNetwork::ConcatEmbeddings( - const std::vector<FeatureVector> &feature_vectors, Vector *concat) const { - concat->resize(concat_layer_size_); - - // Invariant 1: feature_vectors contains exactly one element for each - // embedding space. 
That element is itself a FeatureVector, which may be - // empty, but it should be there. - if (feature_vectors.size() != embedding_matrices_.size()) { - TC_LOG(ERROR) << feature_vectors.size() - << " != " << embedding_matrices_.size(); - return false; - } - - // "es_index" stands for "embedding space index". - for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) { - // Access is safe by es_index loop bounds and Invariant 1. - EmbeddingMatrix *const embedding_matrix = - embedding_matrices_[es_index].get(); - if (embedding_matrix == nullptr) { - // Should not happen, hence our terse log error message. - TC_LOG(ERROR) << es_index; - return false; - } - - // Access is safe due to es_index loop bounds. - const FeatureVector &feature_vector = feature_vectors[es_index]; - - // Access is safe by es_index loop bounds, Invariant 1, and Invariant 2. - const int concat_offset = concat_offset_[es_index]; - - if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset, - concat->data(), concat->size())) { - TC_LOG(ERROR) << es_index; - return false; - } - } - return true; -} - -bool EmbeddingNetwork::GetEmbedding(const FeatureVector &feature_vector, - int es_index, float *embedding) const { - EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get(); - if (embedding_matrix == nullptr) { - // Should not happen, hence our terse log error message. 
- TC_LOG(ERROR) << es_index; - return false; - } - return GetEmbeddingInternal(feature_vector, embedding_matrix, 0, embedding, - embedding_matrices_[es_index]->dim()); -} - -bool EmbeddingNetwork::GetEmbeddingInternal( - const FeatureVector &feature_vector, - EmbeddingMatrix *const embedding_matrix, const int concat_offset, - float *concat, int concat_size) const { - const int embedding_dim = embedding_matrix->dim(); - const bool is_quantized = - embedding_matrix->quant_type() != QuantizationType::NONE; - const int num_features = feature_vector.size(); - for (int fi = 0; fi < num_features; ++fi) { - // Both accesses below are safe due to loop bounds for fi. - const FeatureType *feature_type = feature_vector.type(fi); - const FeatureValue feature_value = feature_vector.value(fi); - const int feature_offset = - concat_offset + feature_type->base() * embedding_dim; - - // Code below updates max(0, embedding_dim) elements from concat, starting - // with index feature_offset. Check below ensures these updates are safe. - if ((feature_offset < 0) || - (feature_offset + embedding_dim > concat_size)) { - TC_LOG(ERROR) << fi << ": " << feature_offset << " " << embedding_dim - << " " << concat_size; - return false; - } - - // Pointer to float / uint8 weights for relevant embedding. - const void *embedding_data; - - // Multiplier for each embedding weight. - float multiplier; - - if (feature_type->is_continuous()) { - // Continuous features (encoded as FloatFeatureValue). - FloatFeatureValue float_feature_value(feature_value); - const int id = float_feature_value.id; - embedding_matrix->get_embedding(id, &embedding_data, &multiplier); - multiplier *= float_feature_value.weight; - } else { - // Discrete features: every present feature has implicit value 1.0. - // Hence, after we grab the multiplier below, we don't multiply it by - // any weight. 
- embedding_matrix->get_embedding(feature_value, &embedding_data, - &multiplier); - } - - // Weighted embeddings will be added starting from this address. - float *concat_ptr = concat + feature_offset; - - if (is_quantized) { - const uint8 *quant_weights = - reinterpret_cast<const uint8 *>(embedding_data); - for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) { - // 128 is bias for UINT8 quantization, only one we currently support. - *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier; - } - } else { - const float *weights = reinterpret_cast<const float *>(embedding_data); - for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) { - *concat_ptr += *weights * multiplier; - } - } - } - return true; -} - -bool EmbeddingNetwork::ComputeLogits(const VectorSpan<float> &input, - Vector *scores) const { - return EmbeddingNetwork::ComputeLogitsInternal(input, scores); -} - -bool EmbeddingNetwork::ComputeLogits(const Vector &input, - Vector *scores) const { - return EmbeddingNetwork::ComputeLogitsInternal(input, scores); -} - -bool EmbeddingNetwork::ComputeLogitsInternal(const VectorSpan<float> &input, - Vector *scores) const { - return FinishComputeFinalScoresInternal<SimpleAdder>(input, scores); -} - -template <typename ScaleAdderClass> -bool EmbeddingNetwork::FinishComputeFinalScoresInternal( - const VectorSpan<float> &input, Vector *scores) const { - // This vector serves as an alternating storage for activations of the - // different layers. We can't use just one vector here because all of the - // activations of the previous layer are needed for computation of - // activations of the next one. - std::vector<Vector> h_storage(2); - - // Compute pre-logits activations. 
- VectorSpan<float> h_in(input); - Vector *h_out; - for (int i = 0; i < hidden_weights_.size(); ++i) { - const bool apply_relu = i > 0; - h_out = &(h_storage[i % 2]); - h_out->resize(hidden_bias_[i].size()); - if (!SparseReluProductPlusBias<ScaleAdderClass>( - apply_relu, hidden_weights_[i], hidden_bias_[i], h_in, h_out)) { - return false; - } - h_in = VectorSpan<float>(*h_out); - } - - // Compute logit scores. - if (!SparseReluProductPlusBias<ScaleAdderClass>( - true, softmax_weights_, softmax_bias_, h_in, scores)) { - return false; - } - - return true; -} - -bool EmbeddingNetwork::ComputeFinalScores( - const std::vector<FeatureVector> &features, Vector *scores) const { - return ComputeFinalScores(features, {}, scores); -} - -bool EmbeddingNetwork::ComputeFinalScores( - const std::vector<FeatureVector> &features, - const std::vector<float> extra_inputs, Vector *scores) const { - // If we haven't successfully initialized, return without doing anything. - if (!is_valid()) return false; - - Vector concat; - if (!ConcatEmbeddings(features, &concat)) return false; - - if (!extra_inputs.empty()) { - concat.reserve(concat.size() + extra_inputs.size()); - for (int i = 0; i < extra_inputs.size(); i++) { - concat.push_back(extra_inputs[i]); - } - } - - scores->resize(softmax_bias_.size()); - return ComputeLogits(concat, scores); -} - -EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model) { - // We'll set valid_ to true only if construction is successful. If we detect - // an error along the way, we log an informative message and return early, but - // we do not crash. - valid_ = false; - - // Fill embedding_matrices_, concat_offset_, and concat_layer_size_. 
- const int num_embedding_spaces = model->GetNumEmbeddingSpaces(); - int offset_sum = 0; - for (int i = 0; i < num_embedding_spaces; ++i) { - concat_offset_.push_back(offset_sum); - const EmbeddingNetworkParams::Matrix matrix = model->GetEmbeddingMatrix(i); - if (matrix.quant_type != QuantizationType::UINT8) { - TC_LOG(ERROR) << "Unsupported quantization for embedding #" << i << ": " - << static_cast<int>(matrix.quant_type); - return; - } - - // There is no way to accomodate an empty embedding matrix. E.g., there is - // no way for get_embedding to return something that can be read safely. - // Hence, we catch that error here and return early. - if (matrix.rows == 0) { - TC_LOG(ERROR) << "Empty embedding matrix #" << i; - return; - } - embedding_matrices_.emplace_back(new EmbeddingMatrix(matrix)); - const int embedding_dim = embedding_matrices_.back()->dim(); - offset_sum += embedding_dim * model->GetNumFeaturesInEmbeddingSpace(i); - } - concat_layer_size_ = offset_sum; - - // Invariant 2 (trivial by the code above). 
- TC_DCHECK_EQ(concat_offset_.size(), embedding_matrices_.size()); - - const int num_hidden_layers = model->GetNumHiddenLayers(); - if (num_hidden_layers < 1) { - TC_LOG(ERROR) << "Wrong number of hidden layers: " << num_hidden_layers; - return; - } - hidden_weights_.resize(num_hidden_layers); - hidden_bias_.resize(num_hidden_layers); - - for (int i = 0; i < num_hidden_layers; ++i) { - const EmbeddingNetworkParams::Matrix matrix = - model->GetHiddenLayerMatrix(i); - const EmbeddingNetworkParams::Matrix bias = model->GetHiddenLayerBias(i); - if (!InitNonQuantizedMatrix(matrix, &hidden_weights_[i]) || - !InitNonQuantizedVector(bias, &hidden_bias_[i])) { - TC_LOG(ERROR) << "Bad hidden layer #" << i; - return; - } - } - - if (!model->HasSoftmaxLayer()) { - TC_LOG(ERROR) << "Missing softmax layer"; - return; - } - const EmbeddingNetworkParams::Matrix softmax = model->GetSoftmaxMatrix(); - const EmbeddingNetworkParams::Matrix softmax_bias = model->GetSoftmaxBias(); - if (!InitNonQuantizedMatrix(softmax, &softmax_weights_) || - !InitNonQuantizedVector(softmax_bias, &softmax_bias_)) { - TC_LOG(ERROR) << "Bad softmax layer"; - return; - } - - // Everything looks good. - valid_ = true; -} - -int EmbeddingNetwork::EmbeddingSize(int es_index) const { - return embedding_matrices_[es_index]->dim(); -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/embedding-network.h b/common/embedding-network.h deleted file mode 100644 index a02c6ea..0000000 --- a/common/embedding-network.h +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_ -#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_ - -#include <memory> -#include <vector> - -#include "common/embedding-network-params.h" -#include "common/feature-extractor.h" -#include "common/vector-span.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/base/macros.h" - -namespace libtextclassifier { -namespace nlp_core { - -// Classifier using a hand-coded feed-forward neural network. -// -// No gradient computation, just inference. -// -// Classification works as follows: -// -// Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax -// -// In words: given some discrete features, this class extracts the embeddings -// for these features, concatenates them, passes them through one or two hidden -// layers (each layer uses Relu) and next through a softmax layer that computes -// an unnormalized score for each possible class. Note: there is always a -// softmax layer. -class EmbeddingNetwork { - public: - // Class used to represent an embedding matrix. Each row is the embedding on - // a vocabulary element. Number of columns = number of embedding dimensions. 
- class EmbeddingMatrix { - public: - explicit EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix) - : rows_(source_matrix.rows), - cols_(source_matrix.cols), - quant_type_(source_matrix.quant_type), - data_(source_matrix.elements), - row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)), - quant_scales_(source_matrix.quant_scales) {} - - // Returns vocabulary size; one embedding for each vocabulary element. - int size() const { return rows_; } - - // Returns number of weights in embedding of each vocabulary element. - int dim() const { return cols_; } - - // Returns quantization type for this embedding matrix. - QuantizationType quant_type() const { return quant_type_; } - - // Gets embedding for k-th vocabulary element: on return, sets *data to - // point to the embedding weights and *scale to the quantization scale (1.0 - // if no quantization). - void get_embedding(int k, const void **data, float *scale) const { - if ((k < 0) || (k >= size())) { - TC_LOG(ERROR) << "Index outside [0, " << size() << "): " << k; - - // In debug mode, crash. In prod, pretend that k is 0. - TC_DCHECK(false); - k = 0; - } - *data = reinterpret_cast<const char *>(data_) + k * row_size_in_bytes_; - if (quant_type_ == QuantizationType::NONE) { - *scale = 1.0; - } else { - *scale = Float16To32(quant_scales_[k]); - } - } - - private: - static int GetRowSizeInBytes(int cols, QuantizationType quant_type) { - switch (quant_type) { - case QuantizationType::NONE: - return cols * sizeof(float); - case QuantizationType::UINT8: - return cols * sizeof(uint8); - default: - TC_LOG(ERROR) << "Unknown quant type: " - << static_cast<int>(quant_type); - return 0; - } - } - - // Vocabulary size. - const int rows_; - - // Number of elements in each embedding. - const int cols_; - - const QuantizationType quant_type_; - - // Pointer to the embedding weights, in row-major order. This is a pointer - // to an array of floats / uint8, depending on the quantization type. - // Not owned. 
- const void *const data_; - - // Number of bytes for one row. Used to jump to next row in data_. - const int row_size_in_bytes_; - - // Pointer to quantization scales. nullptr if no quantization. Otherwise, - // quant_scales_[i] is scale for embedding of i-th vocabulary element. - const float16 *const quant_scales_; - - TC_DISALLOW_COPY_AND_ASSIGN(EmbeddingMatrix); - }; - - // An immutable vector that doesn't own the memory that stores the underlying - // floats. Can be used e.g., as a wrapper around model weights stored in the - // static memory. - class VectorWrapper { - public: - VectorWrapper() : VectorWrapper(nullptr, 0) {} - - // Constructs a vector wrapper around the size consecutive floats that start - // at address data. Note: the underlying data should be alive for at least - // the lifetime of this VectorWrapper object. That's trivially true if data - // points to statically allocated data :) - VectorWrapper(const float *data, int size) : data_(data), size_(size) {} - - int size() const { return size_; } - - const float *data() const { return data_; } - - private: - const float *data_; // Not owned. - int size_; - - // Doesn't own anything, so it can be copied and assigned at will :) - }; - - typedef std::vector<VectorWrapper> Matrix; - typedef std::vector<float> Vector; - - // Constructs an embedding network using the parameters from model. - // - // Note: model should stay alive for at least the lifetime of this - // EmbeddingNetwork object. - explicit EmbeddingNetwork(const EmbeddingNetworkParams *model); - - virtual ~EmbeddingNetwork() {} - - // Returns true if this EmbeddingNetwork object has been correctly constructed - // and is ready to use. Idea: in case of errors, mark this EmbeddingNetwork - // object as invalid, but do not crash. - bool is_valid() const { return valid_; } - - // Runs forward computation to fill scores with unnormalized output unit - // scores. This is useful for making predictions. 
- // - // Returns true on success, false on error (e.g., if !is_valid()). - bool ComputeFinalScores(const std::vector<FeatureVector> &features, - Vector *scores) const; - - // Same as above, but allows specification of extra neural network inputs that - // will be appended to the embedding vector build from features. - bool ComputeFinalScores(const std::vector<FeatureVector> &features, - const std::vector<float> extra_inputs, - Vector *scores) const; - - // Constructs the concatenated input embedding vector in place in output - // vector concat. Returns true on success, false on error. - bool ConcatEmbeddings(const std::vector<FeatureVector> &features, - Vector *concat) const; - - // Sums embeddings for all features from |feature_vector| and adds result - // to values from the array pointed-to by |output|. Embeddings for continuous - // features are weighted by the feature weight. - // - // NOTE: output should point to an array of EmbeddingSize(es_index) floats. - bool GetEmbedding(const FeatureVector &feature_vector, int es_index, - float *embedding) const; - - // Runs the feed-forward neural network for |input| and computes logits for - // softmax layer. - bool ComputeLogits(const Vector &input, Vector *scores) const; - - // Same as above but uses a view of the feature vector. - bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const; - - // Returns the size (the number of columns) of the embedding space es_index. - int EmbeddingSize(int es_index) const; - - protected: - // Builds an embedding for given feature vector, and places it from - // concat_offset to the concat vector. - bool GetEmbeddingInternal(const FeatureVector &feature_vector, - EmbeddingMatrix *embedding_matrix, - int concat_offset, float *concat, - int embedding_size) const; - - // Templated function that computes the logit scores given the concatenated - // input embeddings. 
- bool ComputeLogitsInternal(const VectorSpan<float> &concat, - Vector *scores) const; - - // Computes the softmax scores (prior to normalization) from the concatenated - // representation. Returns true on success, false on error. - template <typename ScaleAdderClass> - bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat, - Vector *scores) const; - - // Set to true on successful construction, false otherwise. - bool valid_ = false; - - // Network parameters. - - // One weight matrix for each embedding space. - std::vector<std::unique_ptr<EmbeddingMatrix>> embedding_matrices_; - - // concat_offset_[i] is the input layer offset for i-th embedding space. - std::vector<int> concat_offset_; - - // Size of the input ("concatenation") layer. - int concat_layer_size_; - - // One weight matrix and one vector of bias weights for each hiden layer. - std::vector<Matrix> hidden_weights_; - std::vector<VectorWrapper> hidden_bias_; - - // Weight matrix and bias vector for the softmax layer. - Matrix softmax_weights_; - VectorWrapper softmax_bias_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_ diff --git a/common/embedding-network.proto b/common/embedding-network.proto deleted file mode 100644 index ce30b11..0000000 --- a/common/embedding-network.proto +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// Protos for performing inference with an EmbeddingNetwork. - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -package libtextclassifier.nlp_core; - -// Wrapper for storing a matrix of parameters. These are stored in row-major -// order. -message MatrixParams { - optional int32 rows = 1; // # of rows in the matrix - optional int32 cols = 2; // # of columns in the matrix - - // Non-quantized matrix entries. - repeated float value = 3 [packed = true]; - - // Whether the matrix is quantized. - optional bool is_quantized = 4 [default = false]; - - // Bytes for all quantized values. Each value (see "repeated float value" - // field) is quantized to an uint8 (1 byte) value, and all these bytes are - // concatenated into the string from this field. - optional bytes bytes_for_quantized_values = 7; - - // Bytes for all scale factors for dequantizing the values. The quantization - // process generates a float16 scale factor for each column. The 2 bytes for - // each such float16 are put in little-endian order (least significant byte - // first) and next all these pairs of bytes are concatenated into the string - // from this field. - optional bytes bytes_for_col_scales = 8; - - reserved 5, 6; -} - -// Stores all parameters for a given EmbeddingNetwork. This can either be a -// EmbeddingNetwork or a PrecomputedEmbeddingNetwork: for precomputed networks, -// the embedding weights are actually the activations of the first hidden layer -// *before* the bias is added and the non-linear transform is applied. -// -// Thus, for PrecomputedEmbeddingNetwork storage, hidden layers are stored -// starting from the second hidden layer, while biases are stored for every -// hidden layer. -message EmbeddingNetworkProto { - // Embeddings and hidden layers. 
Note that if is_precomputed == true, then the - // embeddings should store the activations of the first hidden layer, so we - // must have hidden_bias_size() == hidden_size() + 1 (we store weights for - // first hidden layer bias, but no the layer itself.) - repeated MatrixParams embeddings = 1; - repeated MatrixParams hidden = 2; - repeated MatrixParams hidden_bias = 3; - - // Final layer of the network. - optional MatrixParams softmax = 4; - optional MatrixParams softmax_bias = 5; - - // Element i of the repeated field below indicates number of features that use - // the i-th embedding space. - repeated int32 embedding_num_features = 7; - - // Whether or not this is intended to store a precomputed network. - optional bool is_precomputed = 11 [default = false]; - - // True if this EmbeddingNetworkProto can be used for inference with no - // additional matrix transposition. - // - // Given an EmbeddingNetworkProto produced by a Neurosis training pipeline, we - // have to transpose a few matrices (e.g., the embedding matrices) before we - // can perform inference. When we do so, we negate this flag. Note: we don't - // simply set this to true: transposing twice takes us to the original state. - optional bool is_transposed = 12 [default = false]; - - // Allow extensions. - extensions 100 to max; - - reserved 6, 8, 9, 10; -} diff --git a/common/embedding-network_test.cc b/common/embedding-network_test.cc deleted file mode 100644 index 026ec17..0000000 --- a/common/embedding-network_test.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/embedding-network.h" -#include "common/embedding-network-params-from-proto.h" -#include "common/embedding-network.pb.h" -#include "common/simple-adder.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace { - -using testing::ElementsAreArray; - -class TestingEmbeddingNetwork : public EmbeddingNetwork { - public: - using EmbeddingNetwork::EmbeddingNetwork; - using EmbeddingNetwork::FinishComputeFinalScoresInternal; -}; - -void DiagonalAndBias3x3(int diagonal_value, int bias_value, - MatrixParams* weights, MatrixParams* bias) { - weights->set_rows(3); - weights->set_cols(3); - weights->add_value(diagonal_value); - weights->add_value(0); - weights->add_value(0); - weights->add_value(0); - weights->add_value(diagonal_value); - weights->add_value(0); - weights->add_value(0); - weights->add_value(0); - weights->add_value(diagonal_value); - - bias->set_rows(3); - bias->set_cols(1); - bias->add_value(bias_value); - bias->add_value(bias_value); - bias->add_value(bias_value); -} - -TEST(EmbeddingNetworkTest, IdentityThroughMultipleLayers) { - std::unique_ptr<EmbeddingNetworkProto> proto; - proto.reset(new EmbeddingNetworkProto); - - // These layers should be an identity with bias. 
- DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/1, - proto->add_hidden(), proto->add_hidden_bias()); - DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/2, - proto->add_hidden(), proto->add_hidden_bias()); - DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/3, - proto->add_hidden(), proto->add_hidden_bias()); - DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/4, - proto->add_hidden(), proto->add_hidden_bias()); - DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/5, - proto->mutable_softmax(), proto->mutable_softmax_bias()); - - EmbeddingNetworkParamsFromProto params(std::move(proto)); - TestingEmbeddingNetwork network(¶ms); - - std::vector<float> input({-2, -1, 0}); - std::vector<float> output; - network.FinishComputeFinalScoresInternal<SimpleAdder>( - VectorSpan<float>(input), &output); - - EXPECT_THAT(output, ElementsAreArray({14, 14, 15})); -} - -} // namespace -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/feature-descriptors.h b/common/feature-descriptors.h deleted file mode 100644 index 9aa6527..0000000 --- a/common/feature-descriptors.h +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_DESCRIPTORS_H_ -#define LIBTEXTCLASSIFIER_COMMON_FEATURE_DESCRIPTORS_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/base/macros.h" - -namespace libtextclassifier { -namespace nlp_core { - -// Named feature parameter. -class Parameter { - public: - Parameter() {} - - void set_name(const std::string &name) { name_ = name; } - const std::string &name() const { return name_; } - - void set_value(const std::string &value) { value_ = value; } - const std::string &value() const { return value_; } - - private: - std::string name_; - std::string value_; -}; - -// Descriptor for a feature function. Used to store the results of parsing one -// feature function. -class FeatureFunctionDescriptor { - public: - FeatureFunctionDescriptor() {} - - // Accessors for the feature function type. The function type is the string - // that the feature extractor code is registered under. - void set_type(const std::string &type) { type_ = type; } - bool has_type() const { return !type_.empty(); } - const std::string &type() const { return type_; } - - // Accessors for the feature function name. The function name (if available) - // is used for some log messages. Otherwise, a more precise, but also more - // verbose name based on the feature specification is used. - void set_name(const std::string &name) { name_ = name; } - bool has_name() const { return !name_.empty(); } - const std::string &name() { return name_; } - - // Accessors for the default (name-less) parameter. - void set_argument(int32 argument) { argument_ = argument; } - bool has_argument() const { - // If argument has not been specified, clients should treat it as 0. This - // makes the test below correct, without having a separate has_argument_ - // bool field. 
- return argument_ != 0; - } - int32 argument() const { return argument_; } - - // Accessors for the named parameters. - Parameter *add_parameter() { - parameters_.emplace_back(); - return &(parameters_.back()); - } - int parameter_size() const { return parameters_.size(); } - const Parameter ¶meter(int i) const { - TC_DCHECK((i >= 0) && (i < parameter_size())); - return parameters_[i]; - } - - // Accessors for the sub (i.e., nested) features. Nested features: as in - // offset(1).label. - FeatureFunctionDescriptor *add_feature() { - sub_features_.emplace_back(new FeatureFunctionDescriptor()); - return sub_features_.back().get(); - } - int feature_size() const { return sub_features_.size(); } - const FeatureFunctionDescriptor &feature(int i) const { - TC_DCHECK((i >= 0) && (i < feature_size())); - return *(sub_features_[i].get()); - } - FeatureFunctionDescriptor *mutable_feature(int i) { - TC_DCHECK((i >= 0) && (i < feature_size())); - return sub_features_[i].get(); - } - - private: - // See comments for set_type(). - std::string type_; - - // See comments for set_name(). - std::string name_; - - // See comments for set_argument(). - int32 argument_ = 0; - - // See comemnts for add_parameter(). - std::vector<Parameter> parameters_; - - // See comments for add_feature(). - std::vector<std::unique_ptr<FeatureFunctionDescriptor>> sub_features_; - - TC_DISALLOW_COPY_AND_ASSIGN(FeatureFunctionDescriptor); -}; - -// List of FeatureFunctionDescriptors. Used to store the result of parsing the -// spec for several feature functions. 
-class FeatureExtractorDescriptor { - public: - FeatureExtractorDescriptor() {} - - int feature_size() const { return features_.size(); } - - FeatureFunctionDescriptor *add_feature() { - features_.emplace_back(new FeatureFunctionDescriptor()); - return features_.back().get(); - } - - const FeatureFunctionDescriptor &feature(int i) const { - TC_DCHECK((i >= 0) && (i < feature_size())); - return *(features_[i].get()); - } - - FeatureFunctionDescriptor *mutable_feature(int i) { - TC_DCHECK((i >= 0) && (i < feature_size())); - return features_[i].get(); - } - - private: - std::vector<std::unique_ptr<FeatureFunctionDescriptor>> features_; - - TC_DISALLOW_COPY_AND_ASSIGN(FeatureExtractorDescriptor); -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_FEATURE_DESCRIPTORS_H_ diff --git a/common/feature-extractor.cc b/common/feature-extractor.cc deleted file mode 100644 index 12de46d..0000000 --- a/common/feature-extractor.cc +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "common/feature-extractor.h" - -#include "common/feature-types.h" -#include "common/fml-parser.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/gtl/stl_util.h" -#include "util/strings/numbers.h" - -namespace libtextclassifier { -namespace nlp_core { - -constexpr FeatureValue GenericFeatureFunction::kNone; - -GenericFeatureExtractor::GenericFeatureExtractor() {} - -GenericFeatureExtractor::~GenericFeatureExtractor() {} - -bool GenericFeatureExtractor::Parse(const std::string &source) { - // Parse feature specification into descriptor. - FMLParser parser; - if (!parser.Parse(source, mutable_descriptor())) return false; - - // Initialize feature extractor from descriptor. - if (!InitializeFeatureFunctions()) return false; - return true; -} - -bool GenericFeatureExtractor::InitializeFeatureTypes() { - // Register all feature types. - GetFeatureTypes(&feature_types_); - for (size_t i = 0; i < feature_types_.size(); ++i) { - FeatureType *ft = feature_types_[i]; - ft->set_base(i); - - // Check for feature space overflow. - double domain_size = ft->GetDomainSize(); - if (domain_size < 0) { - TC_LOG(ERROR) << "Illegal domain size for feature " << ft->name() << ": " - << domain_size; - return false; - } - } - return true; -} - -FeatureValue GenericFeatureExtractor::GetDomainSize() const { - // Domain size of the set of features is equal to: - // [largest domain size of any feature types] * [number of feature types] - FeatureValue max_feature_type_dsize = 0; - for (size_t i = 0; i < feature_types_.size(); ++i) { - FeatureType *ft = feature_types_[i]; - const FeatureValue feature_type_dsize = ft->GetDomainSize(); - if (feature_type_dsize > max_feature_type_dsize) { - max_feature_type_dsize = feature_type_dsize; - } - } - - return max_feature_type_dsize * feature_types_.size(); -} - -std::string GenericFeatureFunction::GetParameter( - const std::string &name) const { - // Find named parameter in feature descriptor. 
- for (int i = 0; i < descriptor_->parameter_size(); ++i) { - if (name == descriptor_->parameter(i).name()) { - return descriptor_->parameter(i).value(); - } - } - return ""; -} - -GenericFeatureFunction::GenericFeatureFunction() {} - -GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; } - -int GenericFeatureFunction::GetIntParameter(const std::string &name, - int default_value) const { - int32 parsed_value = default_value; - std::string value = GetParameter(name); - if (!value.empty()) { - if (!ParseInt32(value.c_str(), &parsed_value)) { - // A parameter value has been specified, but it can't be parsed as an int. - // We don't crash: instead, we long an error and return the default value. - TC_LOG(ERROR) << "Value of param " << name << " is not an int: " << value; - } - } - return parsed_value; -} - -bool GenericFeatureFunction::GetBoolParameter(const std::string &name, - bool default_value) const { - std::string value = GetParameter(name); - if (value.empty()) return default_value; - if (value == "true") return true; - if (value == "false") return false; - TC_LOG(ERROR) << "Illegal value '" << value << "' for bool parameter '" - << name << "'" - << " will assume default " << default_value; - return default_value; -} - -void GenericFeatureFunction::GetFeatureTypes( - std::vector<FeatureType *> *types) const { - if (feature_type_ != nullptr) types->push_back(feature_type_); -} - -FeatureType *GenericFeatureFunction::GetFeatureType() const { - // If a single feature type has been registered return it. - if (feature_type_ != nullptr) return feature_type_; - - // Get feature types for function. - std::vector<FeatureType *> types; - GetFeatureTypes(&types); - - // If there is exactly one feature type return this, else return null. 
- if (types.size() == 1) return types[0]; - return nullptr; -} - -std::string GenericFeatureFunction::name() const { - std::string output; - if (descriptor_->name().empty()) { - if (!prefix_.empty()) { - output.append(prefix_); - output.append("."); - } - ToFML(*descriptor_, &output); - } else { - output = descriptor_->name(); - } - return output; -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/feature-extractor.h b/common/feature-extractor.h deleted file mode 100644 index bdba609..0000000 --- a/common/feature-extractor.h +++ /dev/null @@ -1,665 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Generic feature extractor for extracting features from objects. The feature -// extractor can be used for extracting features from any object. The feature -// extractor and feature function classes are template classes that have to -// be instantiated for extracting feature from a specific object type. -// -// A feature extractor consists of a hierarchy of feature functions. Each -// feature function extracts one or more feature type and value pairs from the -// object. -// -// The feature extractor has a modular design where new feature functions can be -// registered as components. The feature extractor is initialized from a -// descriptor represented by a protocol buffer. 
The feature extractor can also -// be initialized from a text-based source specification of the feature -// extractor. Feature specification parsers can be added as components. By -// default the feature extractor can be read from an ASCII protocol buffer or in -// a simple feature modeling language (fml). - -// A feature function is invoked with a focus. Nested feature function can be -// invoked with another focus determined by the parent feature function. - -#ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_ -#define LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_ - -#include <stddef.h> - -#include <string> -#include <vector> - -#include "common/feature-descriptors.h" -#include "common/feature-types.h" -#include "common/fml-parser.h" -#include "common/registry.h" -#include "common/task-context.h" -#include "common/workspace.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/base/macros.h" -#include "util/gtl/stl_util.h" - -namespace libtextclassifier { -namespace nlp_core { - -typedef int64 Predicate; -typedef Predicate FeatureValue; - -// A union used to represent discrete and continuous feature values. -union FloatFeatureValue { - public: - explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {} - FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {} - FeatureValue discrete_value; - struct { - uint32 id; - float weight; - }; -}; - -// A feature vector contains feature type and value pairs. -class FeatureVector { - public: - FeatureVector() {} - - // Adds feature type and value pair to feature vector. - void add(FeatureType *type, FeatureValue value) { - features_.emplace_back(type, value); - } - - // Removes all elements from the feature vector. - void clear() { features_.clear(); } - - // Returns the number of elements in the feature vector. - int size() const { return features_.size(); } - - // Reserves space in the underlying feature vector. 
- void reserve(int n) { features_.reserve(n); } - - // Returns feature type for an element in the feature vector. - FeatureType *type(int index) const { return features_[index].type; } - - // Returns feature value for an element in the feature vector. - FeatureValue value(int index) const { return features_[index].value; } - - private: - // Structure for holding feature type and value pairs. - struct Element { - Element() : type(nullptr), value(-1) {} - Element(FeatureType *t, FeatureValue v) : type(t), value(v) {} - - FeatureType *type; - FeatureValue value; - }; - - // Array for storing feature vector elements. - std::vector<Element> features_; - - TC_DISALLOW_COPY_AND_ASSIGN(FeatureVector); -}; - -// The generic feature extractor is the type-independent part of a feature -// extractor. This holds the descriptor for the feature extractor and the -// collection of feature types used in the feature extractor. The feature -// types are not available until FeatureExtractor<>::Init() has been called. -class GenericFeatureExtractor { - public: - GenericFeatureExtractor(); - virtual ~GenericFeatureExtractor(); - - // Initializes the feature extractor from an FML string specification. For - // the FML specification grammar, see fml-parser.h. - // - // Returns true on success, false on syntax error. - bool Parse(const std::string &source); - - // Returns the feature extractor descriptor. - const FeatureExtractorDescriptor &descriptor() const { return descriptor_; } - FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; } - - // Returns the number of feature types in the feature extractor. Invalid - // before Init() has been called. - int feature_types() const { return feature_types_.size(); } - - // Returns a feature type used in the extractor. Invalid before Init() has - // been called. - const FeatureType *feature_type(int index) const { - return feature_types_[index]; - } - - // Returns the feature domain size of this feature extractor. 
- // NOTE: The way that domain size is calculated is, for some, unintuitive. It - // is the largest domain size of any feature type. - FeatureValue GetDomainSize() const; - - protected: - // Initializes the feature types used by the extractor. Called from - // FeatureExtractor<>::Init(). - // - // Returns true on success, false on error. - bool InitializeFeatureTypes(); - - private: - // Initializes the top-level feature functions. - virtual bool InitializeFeatureFunctions() = 0; - - // Returns all feature types used by the extractor. The feature types are - // added to the result array. - virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0; - - // Descriptor for the feature extractor. This is a protocol buffer that - // contains all the information about the feature extractor. The feature - // functions are initialized from the information in the descriptor. - FeatureExtractorDescriptor descriptor_; - - // All feature types used by the feature extractor. The collection of all the - // feature types describes the feature space of the feature set produced by - // the feature extractor. Not owned. - std::vector<FeatureType *> feature_types_; - - TC_DISALLOW_COPY_AND_ASSIGN(GenericFeatureExtractor); -}; - -// The generic feature function is the type-independent part of a feature -// function. Each feature function is associated with the descriptor that it is -// instantiated from. The feature types associated with this feature function -// will be established by the time FeatureExtractor<>::Init() completes. -class GenericFeatureFunction { - public: - // A feature value that represents the absence of a value. - static constexpr FeatureValue kNone = -1; - - GenericFeatureFunction(); - virtual ~GenericFeatureFunction(); - - // Sets up the feature function. NB: FeatureTypes of nested functions are not - // guaranteed to be available until Init(). - // - // Returns true on success, false on error. 
- virtual bool Setup(TaskContext *context) { return true; } - - // Initializes the feature function. NB: The FeatureType of this function must - // be established when this method completes. - // - // Returns true on success, false on error. - virtual bool Init(TaskContext *context) { return true; } - - // Requests workspaces from a registry to obtain indices into a WorkspaceSet - // for any Workspace objects used by this feature function. NB: This will be - // called after Init(), so it can depend on resources and arguments. - virtual void RequestWorkspaces(WorkspaceRegistry *registry) {} - - // Appends the feature types produced by the feature function to types. The - // default implementation appends feature_type(), if non-null. Invalid - // before Init() has been called. - virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const; - - // Returns the feature type for feature produced by this feature function. If - // the feature function produces features of different types this returns - // null. Invalid before Init() has been called. - virtual FeatureType *GetFeatureType() const; - - // Returns the name of the registry used for creating the feature function. - // This can be used for checking if two feature functions are of the same - // kind. - virtual const char *RegistryName() const = 0; - - // Returns the value of a named parameter from the feature function - // descriptor. Returns empty string ("") if parameter is not found. - std::string GetParameter(const std::string &name) const; - - // Returns the int value of a named parameter from the feature function - // descriptor. Returns default_value if the parameter is not found or if its - // value can't be parsed as an int. - int GetIntParameter(const std::string &name, int default_value) const; - - // Returns the bool value of a named parameter from the feature function - // descriptor. Returns default_value if the parameter is not found or if its - // value is not "true" or "false". 
- bool GetBoolParameter(const std::string &name, bool default_value) const; - - // Returns the FML function description for the feature function, i.e. the - // name and parameters without the nested features. - std::string FunctionName() const { - std::string output; - ToFMLFunction(*descriptor_, &output); - return output; - } - - // Returns the prefix for nested feature functions. This is the prefix of this - // feature function concatenated with the feature function name. - std::string SubPrefix() const { - return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName(); - } - - // Returns/sets the feature extractor this function belongs to. - GenericFeatureExtractor *extractor() const { return extractor_; } - void set_extractor(GenericFeatureExtractor *extractor) { - extractor_ = extractor; - } - - // Returns/sets the feature function descriptor. - FeatureFunctionDescriptor *descriptor() const { return descriptor_; } - void set_descriptor(FeatureFunctionDescriptor *descriptor) { - descriptor_ = descriptor; - } - - // Returns a descriptive name for the feature function. The name is taken from - // the descriptor for the feature function. If the name is empty or the - // feature function is a variable the name is the FML representation of the - // feature, including the prefix. - std::string name() const; - - // Returns the argument from the feature function descriptor. It defaults to - // 0 if the argument has not been specified. - int argument() const { - return descriptor_->has_argument() ? descriptor_->argument() : 0; - } - - // Returns/sets/clears function name prefix. - const std::string &prefix() const { return prefix_; } - void set_prefix(const std::string &prefix) { prefix_ = prefix; } - - protected: - // Returns the feature type for single-type feature functions. - FeatureType *feature_type() const { return feature_type_; } - - // Sets the feature type for single-type feature functions. This takes - // ownership of feature_type. 
Can only be called once with a non-null - // pointer. - void set_feature_type(FeatureType *feature_type) { - TC_DCHECK_NE(feature_type, nullptr); - feature_type_ = feature_type; - } - - private: - // Feature extractor this feature function belongs to. Not owned. - GenericFeatureExtractor *extractor_ = nullptr; - - // Descriptor for feature function. Not owned. - FeatureFunctionDescriptor *descriptor_ = nullptr; - - // Feature type for features produced by this feature function. If the - // feature function produces features of multiple feature types this is null - // and the feature function must return it's feature types in - // GetFeatureTypes(). Owned. - FeatureType *feature_type_ = nullptr; - - // Prefix used for sub-feature types of this function. - std::string prefix_; -}; - -// Feature function that can extract features from an object. Templated on -// two type arguments: -// -// OBJ: The "object" from which features are extracted; e.g., a sentence. This -// should be a plain type, rather than a reference or pointer. -// -// ARGS: A set of 0 or more types that are used to "index" into some part of the -// object that should be extracted, e.g. an int token index for a sentence -// object. This should not be a reference type. -template <class OBJ, class... ARGS> -class FeatureFunction - : public GenericFeatureFunction, - public RegisterableClass<FeatureFunction<OBJ, ARGS...> > { - public: - using Self = FeatureFunction<OBJ, ARGS...>; - - // Preprocesses the object. This will be called prior to calling Evaluate() - // or Compute() on that object. - virtual void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {} - - // Appends features computed from the object and focus to the result. The - // default implementation delegates to Compute(), adding a single value if - // available. Multi-valued feature functions must override this method. - virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, - ARGS... 
args, FeatureVector *result) const { - FeatureValue value = Compute(workspaces, object, args..., result); - if (value != kNone) result->add(feature_type(), value); - } - - // Returns a feature value computed from the object and focus, or kNone if no - // value is computed. Single-valued feature functions only need to override - // this method. - virtual FeatureValue Compute(const WorkspaceSet &workspaces, - const OBJ &object, ARGS... args, - const FeatureVector *fv) const { - return kNone; - } - - // Instantiates a new feature function in a feature extractor from a feature - // descriptor. - static Self *Instantiate(GenericFeatureExtractor *extractor, - FeatureFunctionDescriptor *fd, - const std::string &prefix) { - Self *f = Self::Create(fd->type()); - if (f != nullptr) { - f->set_extractor(extractor); - f->set_descriptor(fd); - f->set_prefix(prefix); - } - return f; - } - - // Returns the name of the registry for the feature function. - const char *RegistryName() const override { return Self::registry()->name(); } - - private: - // Special feature function class for resolving variable references. The type - // of the feature function is used for resolving the variable reference. When - // evaluated it will either get the feature value(s) from the variable portion - // of the feature vector, if present, or otherwise it will call the referenced - // feature extractor function directly to extract the feature(s). - class Reference; -}; - -// Base class for features with nested feature functions. The nested functions -// are of type NES, which may be different from the type of the parent function. -// NB: NestedFeatureFunction will ensure that all initialization of nested -// functions takes place during Setup() and Init() -- after the nested features -// are initialized, the parent feature is initialized via SetupNested() and -// InitNested(). 
Alternatively, a derived classes that overrides Setup() and -// Init() directly should call Parent::Setup(), Parent::Init(), etc. first. -// -// Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or -// Compute, since the nested functions may be of a different type. -template <class NES, class OBJ, class... ARGS> -class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> { - public: - using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>; - - // Clean up nested functions. - ~NestedFeatureFunction() override { - // Fully qualified class name, to avoid an ambiguity error when building for - // Android. - ::libtextclassifier::STLDeleteElements(&nested_); - } - - // By default, just appends the nested feature types. - void GetFeatureTypes(std::vector<FeatureType *> *types) const override { - // It's odd if a NestedFeatureFunction does not have anything nested inside - // it, so we crash in debug mode. Still, nothing should crash in prod mode. - TC_DCHECK(!this->nested().empty()) - << "Nested features require nested features to be defined."; - for (auto *function : nested_) function->GetFeatureTypes(types); - } - - // Sets up the nested features. - bool Setup(TaskContext *context) override { - bool success = CreateNested(this->extractor(), this->descriptor(), &nested_, - this->SubPrefix()); - if (!success) { - return false; - } - for (auto *function : nested_) { - if (!function->Setup(context)) return false; - } - if (!SetupNested(context)) { - return false; - } - return true; - } - - // Sets up this NestedFeatureFunction specifically. - virtual bool SetupNested(TaskContext *context) { return true; } - - // Initializes the nested features. - bool Init(TaskContext *context) override { - for (auto *function : nested_) { - if (!function->Init(context)) return false; - } - if (!InitNested(context)) return false; - return true; - } - - // Initializes this NestedFeatureFunction specifically. 
- virtual bool InitNested(TaskContext *context) { return true; } - - // Gets all the workspaces needed for the nested functions. - void RequestWorkspaces(WorkspaceRegistry *registry) override { - for (auto *function : nested_) function->RequestWorkspaces(registry); - } - - // Returns the list of nested feature functions. - const std::vector<NES *> &nested() const { return nested_; } - - // Instantiates nested feature functions for a feature function. Creates and - // initializes one feature function for each sub-descriptor in the feature - // descriptor. - static bool CreateNested(GenericFeatureExtractor *extractor, - FeatureFunctionDescriptor *fd, - std::vector<NES *> *functions, - const std::string &prefix) { - for (int i = 0; i < fd->feature_size(); ++i) { - FeatureFunctionDescriptor *sub = fd->mutable_feature(i); - NES *f = NES::Instantiate(extractor, sub, prefix); - if (f == nullptr) { - return false; - } - functions->push_back(f); - } - return true; - } - - protected: - // The nested feature functions, if any, in order of declaration in the - // feature descriptor. Owned. - std::vector<NES *> nested_; -}; - -// Base class for a nested feature function that takes nested features with the -// same signature as these features, i.e. a meta feature. For this class, we can -// provide preprocessing of the nested features. -template <class OBJ, class... ARGS> -class MetaFeatureFunction - : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ, - ARGS...> { - public: - // Preprocesses using the nested features. 
- void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override { - for (auto *function : this->nested_) { - function->Preprocess(workspaces, object); - } - } -}; - -// Template for a special type of locator: The locator of type -// FeatureFunction<OBJ, ARGS...> calls nested functions of type -// FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is -// responsible for translating by providing the following: -// -// // Gets the new additional focus. -// IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object); -// -// This is useful to e.g. add a token focus to a parser state based on some -// desired property of that state. -template <class DER, class OBJ, class IDX, class... ARGS> -class FeatureAddFocusLocator - : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ, - ARGS...> { - public: - void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override { - for (auto *function : this->nested_) { - function->Preprocess(workspaces, object); - } - } - - void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, - FeatureVector *result) const override { - IDX focus = - static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); - for (auto *function : this->nested()) { - function->Evaluate(workspaces, object, focus, args..., result); - } - } - - // Returns the first nested feature's computed value. - FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, - ARGS... args, - const FeatureVector *result) const override { - IDX focus = - static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); - return this->nested()[0]->Compute(workspaces, object, focus, args..., - result); - } -}; - -// CRTP feature locator class. This is a meta feature that modifies ARGS and -// then calls the nested feature functions with the modified ARGS. 
Note that in -// order for this template to work correctly, all of ARGS must be types for -// which the reference operator & can be interpreted as a pointer to the -// argument. The derived class DER must implement the UpdateFocus method which -// takes pointers to the ARGS arguments: -// -// // Updates the current arguments. -// void UpdateArgs(const OBJ &object, ARGS *...args) const; -template <class DER, class OBJ, class... ARGS> -class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> { - public: - // Feature locators have an additional check that there is no intrinsic type, - // but only in debug mode: having an intrinsic type here is odd, but not - // enough to motive a crash in prod. - void GetFeatureTypes(std::vector<FeatureType *> *types) const override { - TC_DCHECK_EQ(this->feature_type(), nullptr) - << "FeatureLocators should not have an intrinsic type."; - MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types); - } - - // Evaluates the locator. - void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, - FeatureVector *result) const override { - static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); - for (auto *function : this->nested()) { - function->Evaluate(workspaces, object, args..., result); - } - } - - // Returns the first nested feature's computed value. - FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, - ARGS... args, - const FeatureVector *result) const override { - static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); - return this->nested()[0]->Compute(workspaces, object, args..., result); - } -}; - -// Feature extractor for extracting features from objects of a certain class. -// Template type parameters are as defined for FeatureFunction. -template <class OBJ, class... ARGS> -class FeatureExtractor : public GenericFeatureExtractor { - public: - // Feature function type for top-level functions in the feature extractor. 
- typedef FeatureFunction<OBJ, ARGS...> Function; - typedef FeatureExtractor<OBJ, ARGS...> Self; - - // Feature locator type for the feature extractor. - template <class DER> - using Locator = FeatureLocator<DER, OBJ, ARGS...>; - - // Initializes feature extractor. - FeatureExtractor() {} - - ~FeatureExtractor() override { - // Fully qualified class name, to avoid an ambiguity error when building for - // Android. - ::libtextclassifier::STLDeleteElements(&functions_); - } - - // Sets up the feature extractor. Note that only top-level functions exist - // until Setup() is called. This does not take ownership over the context, - // which must outlive this. - bool Setup(TaskContext *context) { - for (Function *function : functions_) { - if (!function->Setup(context)) return false; - } - return true; - } - - // Initializes the feature extractor. Must be called after Setup(). This - // does not take ownership over the context, which must outlive this. - bool Init(TaskContext *context) { - for (Function *function : functions_) { - if (!function->Init(context)) return false; - } - if (!this->InitializeFeatureTypes()) { - return false; - } - return true; - } - - // Requests workspaces from the registry. Must be called after Init(), and - // before Preprocess(). Does not take ownership over registry. This should be - // the same registry used to initialize the WorkspaceSet used in Preprocess() - // and ExtractFeatures(). NB: This is a different ordering from that used in - // SentenceFeatureRepresentation style feature computation. - void RequestWorkspaces(WorkspaceRegistry *registry) { - for (auto *function : functions_) function->RequestWorkspaces(registry); - } - - // Preprocesses the object using feature functions for the phase. Must be - // called before any calls to ExtractFeatures() on that object and phase. 
- void Preprocess(WorkspaceSet *workspaces, OBJ *object) const { - for (Function *function : functions_) { - function->Preprocess(workspaces, object); - } - } - - // Extracts features from an object with a focus. This invokes all the - // top-level feature functions in the feature extractor. Only feature - // functions belonging to the specified phase are invoked. - void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object, - ARGS... args, FeatureVector *result) const { - result->reserve(this->feature_types()); - - // Extract features. - for (int i = 0; i < functions_.size(); ++i) { - functions_[i]->Evaluate(workspaces, object, args..., result); - } - } - - private: - // Creates and initializes all feature functions in the feature extractor. - bool InitializeFeatureFunctions() override { - // Create all top-level feature functions. - for (int i = 0; i < descriptor().feature_size(); ++i) { - FeatureFunctionDescriptor *fd = mutable_descriptor()->mutable_feature(i); - Function *function = Function::Instantiate(this, fd, ""); - if (function == nullptr) return false; - functions_.push_back(function); - } - return true; - } - - // Collect all feature types used in the feature extractor. - void GetFeatureTypes(std::vector<FeatureType *> *types) const override { - for (Function *function : functions_) { - function->GetFeatureTypes(types); - } - } - - // Top-level feature functions (and variables) in the feature extractor. - // Owned. INVARIANT: contains only non-null pointers. 
- std::vector<Function *> functions_; -}; - -#define REGISTER_FEATURE_FUNCTION(base, name, component) \ - REGISTER_CLASS_COMPONENT(base, name, component) - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_ diff --git a/common/feature-types.h b/common/feature-types.h deleted file mode 100644 index 92814d9..0000000 --- a/common/feature-types.h +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Common feature types for parser components. - -#ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_TYPES_H_ -#define LIBTEXTCLASSIFIER_COMMON_FEATURE_TYPES_H_ - -#include <algorithm> -#include <map> -#include <string> -#include <utility> - -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/strings/numbers.h" - -namespace libtextclassifier { -namespace nlp_core { - -// TODO(djweiss) Clean this up as well. -// Use the same type for feature values as is used for predicated. -typedef int64 Predicate; -typedef Predicate FeatureValue; - -// Each feature value in a feature vector has a feature type. The feature type -// is used for converting feature type and value pairs to predicate values. The -// feature type can also return names for feature values and calculate the size -// of the feature value domain. 
The FeatureType class is abstract and must be -// specialized for the concrete feature types. -class FeatureType { - public: - // Initializes a feature type. - explicit FeatureType(const std::string &name) - : name_(name), base_(0), - is_continuous_(name.find("continuous") != std::string::npos) { - } - - virtual ~FeatureType() {} - - // Converts a feature value to a name. - virtual std::string GetFeatureValueName(FeatureValue value) const = 0; - - // Returns the size of the feature values domain. - virtual int64 GetDomainSize() const = 0; - - // Returns the feature type name. - const std::string &name() const { return name_; } - - Predicate base() const { return base_; } - void set_base(Predicate base) { base_ = base; } - - // Returns true iff this feature is continuous; see FloatFeatureValue. - bool is_continuous() const { return is_continuous_; } - - private: - // Feature type name. - std::string name_; - - // "Base" feature value: i.e. a "slot" in a global ordering of features. - Predicate base_; - - // See doc for is_continuous(). - bool is_continuous_; -}; - -// Feature type that is defined using an explicit map from FeatureValue to -// std::string values. This can reduce some of the boilerplate when defining -// features that generate enum values. Example usage: -// -// class BeverageSizeFeature : public FeatureFunction<Beverage> -// enum FeatureValue { SMALL, MEDIUM, LARGE }; // values for this feature -// void Init(TaskContext *context) override { -// set_feature_type(new EnumFeatureType("beverage_size", -// {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}}); -// } -// [...] 
-// }; -class EnumFeatureType : public FeatureType { - public: - EnumFeatureType(const std::string &name, - const std::map<FeatureValue, std::string> &value_names) - : FeatureType(name), value_names_(value_names) { - for (const auto &pair : value_names) { - TC_DCHECK_GE(pair.first, 0) - << "Invalid feature value: " << pair.first << ", " << pair.second; - domain_size_ = std::max(domain_size_, pair.first + 1); - } - } - - // Returns the feature name for a given feature value. - std::string GetFeatureValueName(FeatureValue value) const override { - auto it = value_names_.find(value); - if (it == value_names_.end()) { - TC_LOG(ERROR) << "Invalid feature value " << value << " for " << name(); - return "<INVALID>"; - } - return it->second; - } - - // Returns the number of possible values for this feature type. This is one - // greater than the largest value in the value_names map. - FeatureValue GetDomainSize() const override { return domain_size_; } - - protected: - // Maximum possible value this feature could take. - FeatureValue domain_size_ = 0; - - // Names of feature values. - std::map<FeatureValue, std::string> value_names_; -}; - -// Feature type for binary features. -class BinaryFeatureType : public FeatureType { - public: - BinaryFeatureType(const std::string &name, const std::string &off, - const std::string &on) - : FeatureType(name), off_(off), on_(on) {} - - // Returns the feature name for a given feature value. - std::string GetFeatureValueName(FeatureValue value) const override { - if (value == 0) return off_; - if (value == 1) return on_; - return ""; - } - - // Binary features always have two feature values. - FeatureValue GetDomainSize() const override { return 2; } - - private: - // Feature value names for on and off. - std::string off_; - std::string on_; -}; - -// Feature type for numeric features. -class NumericFeatureType : public FeatureType { - public: - // Initializes numeric feature. 
- NumericFeatureType(const std::string &name, FeatureValue size) - : FeatureType(name), size_(size) {} - - // Returns numeric feature value. - std::string GetFeatureValueName(FeatureValue value) const override { - if (value < 0) return ""; - return IntToString(value); - } - - // Returns the number of feature values. - FeatureValue GetDomainSize() const override { return size_; } - - private: - // The underlying size of the numeric feature. - FeatureValue size_; -}; - -// Feature type for byte features, including an "outside" value. -class ByteFeatureType : public NumericFeatureType { - public: - explicit ByteFeatureType(const std::string &name) - : NumericFeatureType(name, 257) {} - - std::string GetFeatureValueName(FeatureValue value) const override { - if (value == 256) { - return "<NULL>"; - } - std::string result; - result += static_cast<char>(value); - return result; - } -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_FEATURE_TYPES_H_ diff --git a/common/file-utils.cc b/common/file-utils.cc deleted file mode 100644 index 6ae4442..0000000 --- a/common/file-utils.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "common/file-utils.h" - -#include <fcntl.h> -#include <stdio.h> -#include <sys/stat.h> -#include <sys/types.h> - -#include <fstream> -#include <memory> -#include <string> - -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -namespace file_utils { - -bool GetFileContent(const std::string &filename, std::string *content) { - std::ifstream input_stream(filename, std::ifstream::binary); - if (input_stream.fail()) { - TC_LOG(INFO) << "Error opening " << filename; - return false; - } - - content->assign( - std::istreambuf_iterator<char>(input_stream), - std::istreambuf_iterator<char>()); - - if (input_stream.fail()) { - TC_LOG(ERROR) << "Error reading " << filename; - return false; - } - - TC_LOG(INFO) << "Successfully read " << filename; - return true; -} - -bool FileExists(const std::string &filename) { - struct stat s = {0}; - if (!stat(filename.c_str(), &s)) { - return s.st_mode & S_IFREG; - } else { - return false; - } -} - -bool DirectoryExists(const std::string &dirpath) { - struct stat s = {0}; - if (!stat(dirpath.c_str(), &s)) { - return s.st_mode & S_IFDIR; - } else { - return false; - } -} - -} // namespace file_utils - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/file-utils.h b/common/file-utils.h deleted file mode 100644 index e2a60f2..0000000 --- a/common/file-utils.h +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_FILE_UTILS_H_ -#define LIBTEXTCLASSIFIER_COMMON_FILE_UTILS_H_ - -#include <cstddef> -#include <memory> -#include <string> - -#include "common/config.h" - -#if PORTABLE_SAFT_MOBILE -#include <google/protobuf/io/zero_copy_stream_impl_lite.h> -#endif - -#include "common/mmap.h" -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { - -namespace file_utils { - -// Reads the entire content of a file into a string. Returns true on success, -// false on error. -bool GetFileContent(const std::string &filename, std::string *content); - -// Parses a proto from its serialized representation in memory. That -// representation starts at address sp.data() and contains exactly sp.size() -// bytes. Returns true on success, false otherwise. -template<class Proto> -bool ParseProtoFromMemory(StringPiece sp, Proto *proto) { - if (!sp.data()) { - // Avoid passing a nullptr to ArrayInputStream below. - return false; - } -#if PORTABLE_SAFT_MOBILE - ::google::protobuf::io::ArrayInputStream stream(sp.data(), sp.size()); - return proto->ParseFromZeroCopyStream(&stream); -#else - - std::string data(sp.data(), sp.size()); - return proto->ParseFromString(data); -#endif -} - -// Parses a proto from a file. Returns true on success, false otherwise. -// -// Note: the entire content of the file should be the binary (not -// human-readable) serialization of a protocol buffer. -// -// Note: when we compile for Android, the proto parsing methods need to know the -// type of the message they are parsing. We use template polymorphism for that. 
-template<class Proto> -bool ReadProtoFromFile(const std::string &filename, Proto *proto) { - ScopedMmap scoped_mmap(filename); - const MmapHandle &handle = scoped_mmap.handle(); - if (!handle.ok()) { - return false; - } - return ParseProtoFromMemory(handle.to_stringpiece(), proto); -} - -// Returns true if filename is the name of an existing file, and false -// otherwise. -bool FileExists(const std::string &filename); - -// Returns true if dirpath is the path to an existing directory, and false -// otherwise. -bool DirectoryExists(const std::string &dirpath); - -} // namespace file_utils - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_FILE_UTILS_H_ diff --git a/common/float16.h b/common/float16.h deleted file mode 100644 index 8b52be3..0000000 --- a/common/float16.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_FLOAT16_H_ -#define LIBTEXTCLASSIFIER_COMMON_FLOAT16_H_ - -#include "util/base/casts.h" -#include "util/base/integral_types.h" - -namespace libtextclassifier { -namespace nlp_core { - -// 16 bit encoding of a float. NOTE: can't be used directly for computation: -// one first needs to convert it to a normal float, using Float16To32. -// -// Documentation copied from original file: -// -// Compact 16-bit encoding of floating point numbers. 
This -// representation uses 1 bit for the sign, 8 bits for the exponent and -// 7 bits for the mantissa. It is assumed that floats are in IEEE 754 -// format so a float16 is just bits 16-31 of a single precision float. -// -// NOTE: The IEEE floating point standard defines a float16 format that -// is different than this format (it has fewer bits of exponent and more -// bits of mantissa). We don't use that format here because conversion -// to/from 32-bit floats is more complex for that format, and the -// conversion for this format is very simple. -// -// <---------float16------------> -// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f -// <------------------------------float--------------------------> -// 3 3 2 2 1 1 0 -// 1 0 3 2 5 4 0 - -typedef uint16 float16; - -static inline float16 Float32To16(float f) { - // Note that we just truncate the mantissa bits: we make no effort to - // do any smarter rounding. - return (bit_cast<uint32>(f) >> 16) & 0xffff; -} - -static inline float Float16To32(float16 f) { - // We fill in the new mantissa bits with 0, and don't do anything smarter. - return bit_cast<float>(f << 16); -} - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_FLOAT16_H_ diff --git a/common/fml-parser.cc b/common/fml-parser.cc deleted file mode 100644 index 2964671..0000000 --- a/common/fml-parser.cc +++ /dev/null @@ -1,329 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/fml-parser.h" - -#include <ctype.h> -#include <string> - -#include "util/base/logging.h" -#include "util/strings/numbers.h" - -namespace libtextclassifier { -namespace nlp_core { - -namespace { -inline bool IsValidCharAtStartOfIdentifier(char c) { - return isalpha(c) || (c == '_') || (c == '/'); -} - -// Returns true iff character c can appear inside an identifier. -inline bool IsValidCharInsideIdentifier(char c) { - return isalnum(c) || (c == '_') || (c == '-') || (c == '/'); -} - -// Returns true iff character c can appear at the beginning of a number. -inline bool IsValidCharAtStartOfNumber(char c) { - return isdigit(c) || (c == '+') || (c == '-'); -} - -// Returns true iff character c can appear inside a number. -inline bool IsValidCharInsideNumber(char c) { - return isdigit(c) || (c == '.'); -} -} // namespace - -bool FMLParser::Initialize(const std::string &source) { - // Initialize parser state. - source_ = source; - current_ = source_.begin(); - item_start_ = line_start_ = current_; - line_number_ = item_line_number_ = 1; - - // Read first input item. - return NextItem(); -} - -void FMLParser::ReportError(const std::string &error_message) { - const int position = item_start_ - line_start_ + 1; - const std::string line(line_start_, current_); - - TC_LOG(ERROR) << "Error in feature model, line " << item_line_number_ - << ", position " << position << ": " << error_message - << "\n " << line << " <--HERE"; -} - -void FMLParser::Next() { - // Move to the next input character. If we are at a line break update line - // number and line start position. - if (CurrentChar() == '\n') { - ++line_number_; - ++current_; - line_start_ = current_; - } else { - ++current_; - } -} - -bool FMLParser::NextItem() { - // Skip white space and comments. - while (!eos()) { - if (CurrentChar() == '#') { - // Skip comment. 
- while (!eos() && CurrentChar() != '\n') Next(); - } else if (isspace(CurrentChar())) { - // Skip whitespace. - while (!eos() && isspace(CurrentChar())) Next(); - } else { - break; - } - } - - // Record start position for next item. - item_start_ = current_; - item_line_number_ = line_number_; - - // Check for end of input. - if (eos()) { - item_type_ = END; - return true; - } - - // Parse number. - if (IsValidCharAtStartOfNumber(CurrentChar())) { - std::string::iterator start = current_; - Next(); - while (!eos() && IsValidCharInsideNumber(CurrentChar())) Next(); - item_text_.assign(start, current_); - item_type_ = NUMBER; - return true; - } - - // Parse std::string. - if (CurrentChar() == '"') { - Next(); - std::string::iterator start = current_; - while (CurrentChar() != '"') { - if (eos()) { - ReportError("Unterminated string"); - return false; - } - Next(); - } - item_text_.assign(start, current_); - item_type_ = STRING; - Next(); - return true; - } - - // Parse identifier name. - if (IsValidCharAtStartOfIdentifier(CurrentChar())) { - std::string::iterator start = current_; - while (!eos() && IsValidCharInsideIdentifier(CurrentChar())) { - Next(); - } - item_text_.assign(start, current_); - item_type_ = NAME; - return true; - } - - // Single character item. - item_type_ = CurrentChar(); - Next(); - return true; -} - -bool FMLParser::Parse(const std::string &source, - FeatureExtractorDescriptor *result) { - // Initialize parser. - if (!Initialize(source)) { - return false; - } - - while (item_type_ != END) { - // Current item should be a feature name. - if (item_type_ != NAME) { - ReportError("Feature type name expected"); - return false; - } - std::string name = item_text_; - if (!NextItem()) { - return false; - } - - // Parse feature. 
- FeatureFunctionDescriptor *descriptor = result->add_feature(); - descriptor->set_type(name); - if (!ParseFeature(descriptor)) { - return false; - } - } - - return true; -} - -bool FMLParser::ParseFeature(FeatureFunctionDescriptor *result) { - // Parse argument and parameters. - if (item_type_ == '(') { - if (!NextItem()) return false; - if (!ParseParameter(result)) return false; - while (item_type_ == ',') { - if (!NextItem()) return false; - if (!ParseParameter(result)) return false; - } - - if (item_type_ != ')') { - ReportError(") expected"); - return false; - } - if (!NextItem()) return false; - } - - // Parse feature name. - if (item_type_ == ':') { - if (!NextItem()) return false; - if (item_type_ != NAME && item_type_ != STRING) { - ReportError("Feature name expected"); - return false; - } - std::string name = item_text_; - if (!NextItem()) return false; - - // Set feature name. - result->set_name(name); - } - - // Parse sub-features. - if (item_type_ == '.') { - // Parse dotted sub-feature. - if (!NextItem()) return false; - if (item_type_ != NAME) { - ReportError("Feature type name expected"); - return false; - } - std::string type = item_text_; - if (!NextItem()) return false; - - // Parse sub-feature. - FeatureFunctionDescriptor *subfeature = result->add_feature(); - subfeature->set_type(type); - if (!ParseFeature(subfeature)) return false; - } else if (item_type_ == '{') { - // Parse sub-feature block. - if (!NextItem()) return false; - while (item_type_ != '}') { - if (item_type_ != NAME) { - ReportError("Feature type name expected"); - return false; - } - std::string type = item_text_; - if (!NextItem()) return false; - - // Parse sub-feature. 
- FeatureFunctionDescriptor *subfeature = result->add_feature(); - subfeature->set_type(type); - if (!ParseFeature(subfeature)) return false; - } - if (!NextItem()) return false; - } - return true; -} - -bool FMLParser::ParseParameter(FeatureFunctionDescriptor *result) { - if (item_type_ == NUMBER) { - int32 argument; - if (!ParseInt32(item_text_.c_str(), &argument)) { - ReportError("Unable to parse number"); - return false; - } - if (!NextItem()) return false; - - // Set default argument for feature. - result->set_argument(argument); - } else if (item_type_ == NAME) { - std::string name = item_text_; - if (!NextItem()) return false; - if (item_type_ != '=') { - ReportError("= expected"); - return false; - } - if (!NextItem()) return false; - if (item_type_ >= END) { - ReportError("Parameter value expected"); - return false; - } - std::string value = item_text_; - if (!NextItem()) return false; - - // Add parameter to feature. - Parameter *parameter; - parameter = result->add_parameter(); - parameter->set_name(name); - parameter->set_value(value); - } else { - ReportError("Syntax error in parameter list"); - return false; - } - return true; -} - -void ToFMLFunction(const FeatureFunctionDescriptor &function, - std::string *output) { - output->append(function.type()); - if (function.argument() != 0 || function.parameter_size() > 0) { - output->append("("); - bool first = true; - if (function.argument() != 0) { - output->append(IntToString(function.argument())); - first = false; - } - for (int i = 0; i < function.parameter_size(); ++i) { - if (!first) output->append(","); - output->append(function.parameter(i).name()); - output->append("="); - output->append("\""); - output->append(function.parameter(i).value()); - output->append("\""); - first = false; - } - output->append(")"); - } -} - -void ToFML(const FeatureFunctionDescriptor &function, std::string *output) { - ToFMLFunction(function, output); - if (function.feature_size() == 1) { - output->append("."); - 
ToFML(function.feature(0), output); - } else if (function.feature_size() > 1) { - output->append(" { "); - for (int i = 0; i < function.feature_size(); ++i) { - if (i > 0) output->append(" "); - ToFML(function.feature(i), output); - } - output->append(" } "); - } -} - -void ToFML(const FeatureExtractorDescriptor &extractor, std::string *output) { - for (int i = 0; i < extractor.feature_size(); ++i) { - ToFML(extractor.feature(i), output); - output->append("\n"); - } -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/fml-parser.h b/common/fml-parser.h deleted file mode 100644 index b6b9da2..0000000 --- a/common/fml-parser.h +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Feature modeling language (fml) parser. -// -// BNF grammar for fml: -// -// <feature model> ::= { <feature extractor> } -// -// <feature extractor> ::= <extractor spec> | -// <extractor spec> '.' 
<feature extractor> | -// <extractor spec> '{' { <feature extractor> } '}' -// -// <extractor spec> ::= <extractor type> -// [ '(' <parameter list> ')' ] -// [ ':' <extractor name> ] -// -// <parameter list> = ( <parameter> | <argument> ) { ',' <parameter> } -// -// <parameter> ::= <parameter name> '=' <parameter value> -// -// <extractor type> ::= NAME -// <extractor name> ::= NAME | STRING -// <argument> ::= NUMBER -// <parameter name> ::= NAME -// <parameter value> ::= NUMBER | STRING | NAME - -#ifndef LIBTEXTCLASSIFIER_COMMON_FML_PARSER_H_ -#define LIBTEXTCLASSIFIER_COMMON_FML_PARSER_H_ - -#include <string> -#include <vector> - -#include "common/feature-descriptors.h" -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -class FMLParser { - public: - // Parses fml specification into feature extractor descriptor. - // Returns true on success, false on error (e.g., syntax errors). - bool Parse(const std::string &source, FeatureExtractorDescriptor *result); - - private: - // Initializes the parser with the source text. - // Returns true on success, false on syntax error. - bool Initialize(const std::string &source); - - // Outputs an error message, with context info, and sets error_ to true. - void ReportError(const std::string &error_message); - - // Moves to the next input character. - void Next(); - - // Moves to the next input item. Sets item_text_ and item_type_ accordingly. - // Returns true on success, false on syntax error. - bool NextItem(); - - // Parses a feature descriptor. - // Returns true on success, false on syntax error. - bool ParseFeature(FeatureFunctionDescriptor *result); - - // Parses a parameter specification. - // Returns true on success, false on syntax error. - bool ParseParameter(FeatureFunctionDescriptor *result); - - // Returns true if end of source input has been reached. - bool eos() const { return current_ >= source_.end(); } - - // Returns current character. 
Other methods should access the current - // character through this method (instead of using *current_ directly): this - // method performs extra safety checks. - // - // In case of an unsafe access, returns '\0'. - char CurrentChar() const { - if ((current_ >= source_.begin()) && (current_ < source_.end())) { - return *current_; - } else { - TC_LOG(ERROR) << "Unsafe char read"; - return '\0'; - } - } - - // Item types. - enum ItemTypes { - END = 0, - NAME = -1, - NUMBER = -2, - STRING = -3, - }; - - // Source text. - std::string source_; - - // Current input position. - std::string::iterator current_; - - // Line number for current input position. - int line_number_; - - // Start position for current item. - std::string::iterator item_start_; - - // Start position for current line. - std::string::iterator line_start_; - - // Line number for current item. - int item_line_number_; - - // Item type for current item. If this is positive it is interpreted as a - // character. If it is negative it is interpreted as an item type. - int item_type_; - - // Text for current item. - std::string item_text_; -}; - -// Converts a FeatureFunctionDescriptor into an FML spec (reverse of parsing). -void ToFML(const FeatureFunctionDescriptor &function, std::string *output); - -// Like ToFML, but doesn't go into the nested functions. Instead, it generates -// a string that starts with the name of the feature extraction function and -// next, in-between parentheses, the parameters, separated by comma. -// Intuitively, the constructed string is the prefix of ToFML, before the "{" -// that starts the nested features. 
-void ToFMLFunction(const FeatureFunctionDescriptor &function, - std::string *output); - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_FML_PARSER_H_ diff --git a/common/fml-parser_test.cc b/common/fml-parser_test.cc deleted file mode 100644 index b46048f..0000000 --- a/common/fml-parser_test.cc +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "common/fml-parser.h" - -#include "common/feature-descriptors.h" -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace nlp_core { - -TEST(FMLParserTest, NoFeature) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - const std::string kFeatureName = ""; - EXPECT_TRUE(fml_parser.Parse(kFeatureName, &descriptor)); - EXPECT_EQ(0, descriptor.feature_size()); -} - -TEST(FMLParserTest, FeatureWithNoParams) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - const std::string kFeatureName = "continuous-bag-of-relevant-scripts"; - EXPECT_TRUE(fml_parser.Parse(kFeatureName, &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ(kFeatureName, descriptor.feature(0).type()); -} - -TEST(FMLParserTest, FeatureWithOneKeywordParameter) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_TRUE(fml_parser.Parse("myfeature(start=2)", &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ("myfeature", descriptor.feature(0).type()); - EXPECT_EQ(1, descriptor.feature(0).parameter_size()); - EXPECT_EQ("start", descriptor.feature(0).parameter(0).name()); - EXPECT_EQ("2", descriptor.feature(0).parameter(0).value()); - EXPECT_FALSE(descriptor.feature(0).has_argument()); -} - -TEST(FMLParserTest, FeatureWithDefaultArgumentNegative) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_TRUE(fml_parser.Parse("offset(-3)", &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ("offset", descriptor.feature(0).type()); - EXPECT_EQ(0, descriptor.feature(0).parameter_size()); - EXPECT_EQ(-3, descriptor.feature(0).argument()); -} - -TEST(FMLParserTest, FeatureWithDefaultArgumentPositive) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_TRUE(fml_parser.Parse("delta(7)", &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ("delta", descriptor.feature(0).type()); - EXPECT_EQ(0, 
descriptor.feature(0).parameter_size()); - EXPECT_EQ(7, descriptor.feature(0).argument()); -} - -TEST(FMLParserTest, FeatureWithDefaultArgumentZero) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_TRUE(fml_parser.Parse("delta(0)", &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ("delta", descriptor.feature(0).type()); - EXPECT_EQ(0, descriptor.feature(0).parameter_size()); - EXPECT_EQ(0, descriptor.feature(0).argument()); -} - -TEST(FMLParserTest, FeatureWithManyKeywordParameters) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_TRUE(fml_parser.Parse("myfeature(ratio=0.316,start=2,name=\"foo\")", - &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ("myfeature", descriptor.feature(0).type()); - EXPECT_EQ(3, descriptor.feature(0).parameter_size()); - EXPECT_EQ("ratio", descriptor.feature(0).parameter(0).name()); - EXPECT_EQ("0.316", descriptor.feature(0).parameter(0).value()); - EXPECT_EQ("start", descriptor.feature(0).parameter(1).name()); - EXPECT_EQ("2", descriptor.feature(0).parameter(1).value()); - EXPECT_EQ("name", descriptor.feature(0).parameter(2).name()); - EXPECT_EQ("foo", descriptor.feature(0).parameter(2).value()); - EXPECT_FALSE(descriptor.feature(0).has_argument()); -} - -TEST(FMLParserTest, FeatureWithAllKindsOfParameters) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_TRUE( - fml_parser.Parse("myfeature(17,ratio=0.316,start=2)", &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ("myfeature", descriptor.feature(0).type()); - EXPECT_EQ(2, descriptor.feature(0).parameter_size()); - EXPECT_EQ("ratio", descriptor.feature(0).parameter(0).name()); - EXPECT_EQ("0.316", descriptor.feature(0).parameter(0).value()); - EXPECT_EQ("start", descriptor.feature(0).parameter(1).name()); - EXPECT_EQ("2", descriptor.feature(0).parameter(1).value()); - EXPECT_EQ(17, descriptor.feature(0).argument()); -} - -TEST(FMLParserTest, 
FeatureWithWhitespaces) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_TRUE(fml_parser.Parse( - " myfeature\t\t\t\n(17,\nratio=0.316 , start=2) ", &descriptor)); - EXPECT_EQ(1, descriptor.feature_size()); - EXPECT_EQ("myfeature", descriptor.feature(0).type()); - EXPECT_EQ(2, descriptor.feature(0).parameter_size()); - EXPECT_EQ("ratio", descriptor.feature(0).parameter(0).name()); - EXPECT_EQ("0.316", descriptor.feature(0).parameter(0).value()); - EXPECT_EQ("start", descriptor.feature(0).parameter(1).name()); - EXPECT_EQ("2", descriptor.feature(0).parameter(1).value()); - EXPECT_EQ(17, descriptor.feature(0).argument()); -} - -TEST(FMLParserTest, Broken_ParamWithoutValue) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_FALSE( - fml_parser.Parse("myfeature(17,ratio=0.316,start)", &descriptor)); -} - -TEST(FMLParserTest, Broken_MissingCloseParen) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_FALSE(fml_parser.Parse("myfeature(17,ratio=0.316", &descriptor)); -} - -TEST(FMLParserTest, Broken_MissingOpenParen) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_FALSE(fml_parser.Parse("myfeature17,ratio=0.316)", &descriptor)); -} - -TEST(FMLParserTest, Broken_MissingQuote) { - FMLParser fml_parser; - FeatureExtractorDescriptor descriptor; - EXPECT_FALSE(fml_parser.Parse("count(17,name=\"foo)", &descriptor)); -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/list-of-strings.proto b/common/list-of-strings.proto deleted file mode 100644 index 5ba45ed..0000000 --- a/common/list-of-strings.proto +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -package libtextclassifier.nlp_core; - -message ListOfStrings { - repeated string element = 1; -} diff --git a/common/little-endian-data.h b/common/little-endian-data.h deleted file mode 100644 index e3bc88f..0000000 --- a/common/little-endian-data.h +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_LITTLE_ENDIAN_DATA_H_ -#define LIBTEXTCLASSIFIER_COMMON_LITTLE_ENDIAN_DATA_H_ - -#include <algorithm> -#include <string> -#include <vector> - -#include "util/base/endian.h" -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -// Swaps the sizeof(T) bytes that start at addr. E.g., if sizeof(T) == 2, -// then (addr[0], addr[1]) -> (addr[1], addr[0]). Useful for little endian -// <-> big endian conversions. 
-template <class T> -void SwapBytes(T *addr) { - char *char_ptr = reinterpret_cast<char *>(addr); - std::reverse(char_ptr, char_ptr + sizeof(T)); -} - -// Assuming addr points to a piece of data of type T, with its bytes in the -// little/big endian order specific to the machine this code runs on, this -// method will re-arrange the bytes (in place) in little-endian order. -template <class T> -void HostToLittleEndian(T *addr) { - if (LittleEndian::IsLittleEndian()) { - // Do nothing: current machine is little-endian. - } else { - SwapBytes(addr); - } -} - -// Reverse of HostToLittleEndian. -template <class T> -void LittleEndianToHost(T *addr) { - // It turns out it's the same function: on little-endian machines, do nothing - // (source and target formats are identical). Otherwise, swap bytes. - HostToLittleEndian(addr); -} - -// Returns string obtained by concatenating the bytes of the elements from a -// vector (in order: v[0], v[1], etc). If the type T requires more than one -// byte, the byte for each element are first converted to little-endian format. -template<typename T> -std::string GetDataBytesInLittleEndianOrder(const std::vector<T> &v) { - std::string data_bytes; - for (const T element : v) { - T little_endian_element = element; - HostToLittleEndian(&little_endian_element); - data_bytes.append( - reinterpret_cast<const char *>(&little_endian_element), - sizeof(T)); - } - return data_bytes; -} - -// Performs reverse of GetDataBytesInLittleEndianOrder. -// -// I.e., decodes the data bytes from parameter bytes into num_elements Ts, and -// places them in the vector v (previous content of that vector is erased). -// -// We expect bytes to contain the concatenation of the bytes for exactly -// num_elements elements of type T. If the type T requires more than one byte, -// those bytes should be arranged in little-endian form. -// -// Returns true on success and false otherwise (e.g., bytes has the wrong size). 
-// Note: we do not want to crash on corrupted data (some clients, e..g, GMSCore, -// have asked us not to do so). Instead, we report the error and let the client -// decide what to do. On error, we also fill the vector with zeros, such that -// at least the dimension of v matches expectations. -template<typename T> -bool FillVectorFromDataBytesInLittleEndian( - const std::string &bytes, int num_elements, std::vector<T> *v) { - if (bytes.size() != num_elements * sizeof(T)) { - TC_LOG(ERROR) << "Wrong number of bytes: actual " << bytes.size() - << " vs expected " << num_elements - << " elements of sizeof(element) = " << sizeof(T) - << " bytes each ; will fill vector with zeros"; - v->assign(num_elements, static_cast<T>(0)); - return false; - } - v->clear(); - v->reserve(num_elements); - const T *start = reinterpret_cast<const T *>(bytes.data()); - if (LittleEndian::IsLittleEndian() || (sizeof(T) == 1)) { - // Fast in the common case ([almost] all hardware today is little-endian): - // if same endianness (or type T requires a single byte and endianness - // irrelevant), just use the bytes. - v->assign(start, start + num_elements); - } else { - // Slower (but very rare case): this code runs on a big endian machine and - // the type T requires more than one byte. Hence, some conversion is - // necessary. - for (int i = 0; i < num_elements; ++i) { - T temp = start[i]; - SwapBytes(&temp); - v->push_back(temp); - } - } - return true; -} - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_LITTLE_ENDIAN_DATA_H_ diff --git a/common/memory_image/data-store.cc b/common/memory_image/data-store.cc deleted file mode 100644 index a5f500c..0000000 --- a/common/memory_image/data-store.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/memory_image/data-store.h" - -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace memory_image { - -DataStore::DataStore(StringPiece bytes) - : reader_(bytes.data(), bytes.size()) { - if (!reader_.success_status()) { - TC_LOG(ERROR) << "Unable to successfully initialize DataStore."; - } -} - -StringPiece DataStore::GetData(const std::string &name) const { - if (!reader_.success_status()) { - TC_LOG(ERROR) << "DataStore::GetData(" << name << ") " - << "called on invalid DataStore; will return empty data " - << "chunk"; - return StringPiece(); - } - - const auto &entries = reader_.trimmed_proto().entries(); - const auto &it = entries.find(name); - if (it == entries.end()) { - TC_LOG(ERROR) << "Unknown key: " << name - << "; will return empty data chunk"; - return StringPiece(); - } - - const DataStoreEntryBytes &entry_bytes = it->second; - if (!entry_bytes.has_blob_index()) { - TC_LOG(ERROR) << "DataStoreEntryBytes with no blob_index; " - << "will return empty data chunk."; - return StringPiece(); - } - - int blob_index = entry_bytes.blob_index(); - return reader_.data_blob_view(blob_index); -} - -} // namespace memory_image -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/memory_image/data-store.h b/common/memory_image/data-store.h deleted file mode 100644 index 56aa4fc..0000000 --- a/common/memory_image/data-store.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, 
Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_DATA_STORE_H_ -#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_DATA_STORE_H_ - -#include <string> - -#include "common/memory_image/data-store.pb.h" -#include "common/memory_image/memory-image-reader.h" -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace memory_image { - -// Class to access a data store. See usage example in comments for -// DataStoreBuilder. -class DataStore { - public: - // Constructs a DataStore using the indicated bytes, i.e., bytes.size() bytes - // starting at address bytes.data(). These bytes should contain the - // serialization of a data store, see DataStoreBuilder::SerializeAsString(). - explicit DataStore(StringPiece bytes); - - // Retrieves (start_addr, num_bytes) info for piece of memory that contains - // the data associated with the indicated name. Note: this piece of memory is - // inside the [start, start + size) (see constructor). This piece of memory - // starts at an offset from start which is a multiple of the alignment - // specified when the data store was built using DataStoreBuilder. - // - // If the alignment is a low power of 2 (e..g, 4, 8, or 16) and "start" passed - // to constructor corresponds to the beginning of a memory page or an address - // returned by new or malloc(), then start_addr is divisible with alignment. 
- StringPiece GetData(const std::string &name) const; - - private: - MemoryImageReader<DataStoreProto> reader_; -}; - -} // namespace memory_image -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_DATA_STORE_H_ diff --git a/common/memory_image/data-store.proto b/common/memory_image/data-store.proto deleted file mode 100644 index 68e914a..0000000 --- a/common/memory_image/data-store.proto +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Protos for a data store: a barebone in-memory file system. -// -// A DataStore maintains an association between names and chunks of bytes. It -// can be serialized into a string. Of course, it can be deserialized from a -// string, with minimal parsing; after deserialization, all chunks of bytes -// start at aligned addresses (aligned = multiple of an address specified at -// build time). - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -package libtextclassifier.nlp_core.memory_image; - -// Bytes for a data store entry. They can be stored either directly in the -// "data" field, or in the DataBlob with the 0-based index "blob_index". -message DataStoreEntryBytes { - oneof data { - // Bytes for this data store entry, stored in this message. - string in_place_data = 1; - - // 0-based index of the data blob with bytes for this data store entry. 
In - // this case, the actual bytes are stored outside this message; the - // DataStore code handles the association. - int32 blob_index = 2 [default = -1]; - } -} - -message DataStoreProto { - map<string, DataStoreEntryBytes> entries = 1; -} diff --git a/common/memory_image/embedding-network-params-from-image.h b/common/memory_image/embedding-network-params-from-image.h deleted file mode 100644 index e8c7d1e..0000000 --- a/common/memory_image/embedding-network-params-from-image.h +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_ -#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_ - -#include "common/embedding-network-package.pb.h" -#include "common/embedding-network-params.h" -#include "common/embedding-network.pb.h" -#include "common/memory_image/memory-image-reader.h" -#include "util/base/integral_types.h" -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { - -// EmbeddingNetworkParams backed by a memory image. -// -// In this context, a memory image is like an EmbeddingNetworkProto, but with -// all repeated weights (>99% of the size) directly usable (with no parsing -// required). 
-class EmbeddingNetworkParamsFromImage : public EmbeddingNetworkParams { - public: - // Constructs an EmbeddingNetworkParamsFromImage, using the memory image that - // starts at address start and contains num_bytes bytes. - EmbeddingNetworkParamsFromImage(const void *start, uint64 num_bytes) - : memory_reader_(start, num_bytes), - trimmed_proto_(memory_reader_.trimmed_proto()) { - embeddings_blob_offset_ = 0; - - hidden_blob_offset_ = embeddings_blob_offset_ + embeddings_size(); - if (trimmed_proto_.embeddings_size() && - trimmed_proto_.embeddings(0).is_quantized()) { - // Adjust for quantization: each quantized matrix takes two blobs (instead - // of one): one for the quantized values and one for the scales. - hidden_blob_offset_ += embeddings_size(); - } - - hidden_bias_blob_offset_ = hidden_blob_offset_ + hidden_size(); - softmax_blob_offset_ = hidden_bias_blob_offset_ + hidden_bias_size(); - softmax_bias_blob_offset_ = softmax_blob_offset_ + softmax_size(); - } - - ~EmbeddingNetworkParamsFromImage() override {} - - const TaskSpec *GetTaskSpec() override { - auto extension_id = task_spec_in_embedding_network_proto; - if (trimmed_proto_.HasExtension(extension_id)) { - return &(trimmed_proto_.GetExtension(extension_id)); - } else { - return nullptr; - } - } - - protected: - int embeddings_size() const override { - return trimmed_proto_.embeddings_size(); - } - - int embeddings_num_rows(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - return trimmed_proto_.embeddings(i).rows(); - } - - int embeddings_num_cols(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - return trimmed_proto_.embeddings(i).cols(); - } - - const void *embeddings_weights(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - const int blob_index = trimmed_proto_.embeddings(i).is_quantized() - ? 
(embeddings_blob_offset_ + 2 * i) - : (embeddings_blob_offset_ + i); - StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index); - return data_blob_view.data(); - } - - QuantizationType embeddings_quant_type(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - if (trimmed_proto_.embeddings(i).is_quantized()) { - return QuantizationType::UINT8; - } else { - return QuantizationType::NONE; - } - } - - const float16 *embeddings_quant_scales(int i) const override { - TC_DCHECK(InRange(i, embeddings_size())); - if (trimmed_proto_.embeddings(i).is_quantized()) { - // Each embedding matrix has two atttached data blobs (hence the "2 * i"): - // one blob with the quantized values and (immediately after it, hence the - // "+ 1") one blob with the scales. - int blob_index = embeddings_blob_offset_ + 2 * i + 1; - StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index); - return reinterpret_cast<const float16 *>(data_blob_view.data()); - } else { - return nullptr; - } - } - - int hidden_size() const override { return trimmed_proto_.hidden_size(); } - - int hidden_num_rows(int i) const override { - TC_DCHECK(InRange(i, hidden_size())); - return trimmed_proto_.hidden(i).rows(); - } - - int hidden_num_cols(int i) const override { - TC_DCHECK(InRange(i, hidden_size())); - return trimmed_proto_.hidden(i).cols(); - } - - const void *hidden_weights(int i) const override { - TC_DCHECK(InRange(i, hidden_size())); - StringPiece data_blob_view = - memory_reader_.data_blob_view(hidden_blob_offset_ + i); - return data_blob_view.data(); - } - - int hidden_bias_size() const override { - return trimmed_proto_.hidden_bias_size(); - } - - int hidden_bias_num_rows(int i) const override { - TC_DCHECK(InRange(i, hidden_bias_size())); - return trimmed_proto_.hidden_bias(i).rows(); - } - - int hidden_bias_num_cols(int i) const override { - TC_DCHECK(InRange(i, hidden_bias_size())); - return trimmed_proto_.hidden_bias(i).cols(); - } - - const void 
*hidden_bias_weights(int i) const override { - TC_DCHECK(InRange(i, hidden_bias_size())); - StringPiece data_blob_view = - memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i); - return data_blob_view.data(); - } - - int softmax_size() const override { - return trimmed_proto_.has_softmax() ? 1 : 0; - } - - int softmax_num_rows(int i) const override { - TC_DCHECK(InRange(i, softmax_size())); - return trimmed_proto_.softmax().rows(); - } - - int softmax_num_cols(int i) const override { - TC_DCHECK(InRange(i, softmax_size())); - return trimmed_proto_.softmax().cols(); - } - - const void *softmax_weights(int i) const override { - TC_DCHECK(InRange(i, softmax_size())); - StringPiece data_blob_view = - memory_reader_.data_blob_view(softmax_blob_offset_ + i); - return data_blob_view.data(); - } - - int softmax_bias_size() const override { - return trimmed_proto_.has_softmax_bias() ? 1 : 0; - } - - int softmax_bias_num_rows(int i) const override { - TC_DCHECK(InRange(i, softmax_bias_size())); - return trimmed_proto_.softmax_bias().rows(); - } - - int softmax_bias_num_cols(int i) const override { - TC_DCHECK(InRange(i, softmax_bias_size())); - return trimmed_proto_.softmax_bias().cols(); - } - - const void *softmax_bias_weights(int i) const override { - TC_DCHECK(InRange(i, softmax_bias_size())); - StringPiece data_blob_view = - memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i); - return data_blob_view.data(); - } - - int embedding_num_features_size() const override { - return trimmed_proto_.embedding_num_features_size(); - } - - int embedding_num_features(int i) const override { - TC_DCHECK(InRange(i, embedding_num_features_size())); - return trimmed_proto_.embedding_num_features(i); - } - - private: - MemoryImageReader<EmbeddingNetworkProto> memory_reader_; - - const EmbeddingNetworkProto &trimmed_proto_; - - // 0-based offsets in the list of data blobs for the different MatrixParams - // fields. 
E.g., the 1st hidden MatrixParams has its weights stored in the - // data blob number hidden_blob_offset_, the 2nd one in hidden_blob_offset_ + - // 1, and so on. - int embeddings_blob_offset_; - int hidden_blob_offset_; - int hidden_bias_blob_offset_; - int softmax_blob_offset_; - int softmax_bias_blob_offset_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_ diff --git a/common/memory_image/in-memory-model-data.cc b/common/memory_image/in-memory-model-data.cc deleted file mode 100644 index acf3d86..0000000 --- a/common/memory_image/in-memory-model-data.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "common/memory_image/in-memory-model-data.h" - -#include "common/file-utils.h" -#include "util/base/logging.h" -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { - -const char InMemoryModelData::kTaskSpecDataStoreEntryName[] = "TASK-SPEC-#@"; -const char InMemoryModelData::kFilePatternPrefix[] = "in-mem-model::"; - -bool InMemoryModelData::GetTaskSpec(TaskSpec *task_spec) const { - StringPiece blob = data_store_.GetData(kTaskSpecDataStoreEntryName); - if (blob.data() == nullptr) { - TC_LOG(ERROR) << "Can't find data blob for TaskSpec, i.e., entry " - << kTaskSpecDataStoreEntryName; - return false; - } - bool parse_status = file_utils::ParseProtoFromMemory(blob, task_spec); - if (!parse_status) { - TC_LOG(ERROR) << "Error parsing TaskSpec"; - return false; - } - return true; -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/memory_image/in-memory-model-data.h b/common/memory_image/in-memory-model-data.h deleted file mode 100644 index 91e4436..0000000 --- a/common/memory_image/in-memory-model-data.h +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_IN_MEMORY_MODEL_DATA_H_ -#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_IN_MEMORY_MODEL_DATA_H_ - -#include "common/memory_image/data-store.h" -#include "common/task-spec.pb.h" -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { - -// In-memory representation of data for a Saft model. Provides access to a -// TaskSpec object (produced by the "spec" stage of the Saft training model) and -// to the bytes of the TaskInputs mentioned in that spec (all these bytes are in -// memory, no file I/O required). -// -// Technically, an InMemoryModelData is a DataStore that maps the special string -// kTaskSpecDataStoreEntryName to the binary serialization of a TaskSpec. For -// each TaskInput (of the TaskSpec) with a file_pattern that starts with -// kFilePatternPrefix (see below), the same DataStore maps file_pattern to some -// content bytes. This way, it is possible to have all TaskInputs in memory, -// while still allowing classic, on-disk TaskInputs. -class InMemoryModelData { - public: - // Name for the DataStore entry that stores the serialized TaskSpec for the - // entire model. - static const char kTaskSpecDataStoreEntryName[]; - - // Returns prefix for TaskInput::Part::file_pattern, to distinguish those - // "files" from other files. - static const char kFilePatternPrefix[]; - - // Constructs an InMemoryModelData based on a chunk of bytes. Those bytes - // should have been produced by a DataStoreBuilder. - explicit InMemoryModelData(StringPiece bytes) : data_store_(bytes) {} - - // Fills *task_spec with a TaskSpec similar to the one used by - // DataStoreBuilder (when building the bytes used to construct this - // InMemoryModelData) except that each file name - // (TaskInput::Part::file_pattern) is replaced with a name that can be used to - // retrieve the corresponding file content bytes via GetBytesForInputFile(). - // - // Returns true on success, false otherwise. 
- bool GetTaskSpec(TaskSpec *task_spec) const; - - // Gets content bytes for a file. The file_name argument should be the - // file_pattern for a TaskInput from the TaskSpec (see GetTaskSpec()). - // Returns a StringPiece indicating a memory area with the content bytes. On - // error, returns StringPiece(nullptr, 0). - StringPiece GetBytesForInputFile(const std::string &file_name) const { - return data_store_.GetData(file_name); - } - - private: - const memory_image::DataStore data_store_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_IN_MEMORY_MODEL_DATA_H_ diff --git a/common/memory_image/low-level-memory-reader.h b/common/memory_image/low-level-memory-reader.h deleted file mode 100644 index c87c772..0000000 --- a/common/memory_image/low-level-memory-reader.h +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_LOW_LEVEL_MEMORY_READER_H_ -#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_LOW_LEVEL_MEMORY_READER_H_ - -#include <string.h> - -#include <string> - -#include "util/base/endian.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { - -class LowLevelMemReader { - public: - // Constructs a MemReader instance that reads at most num_available_bytes - // starting from address start. - LowLevelMemReader(const void *start, uint64 num_available_bytes) - : current_(reinterpret_cast<const char *>(start)), - // 0 bytes available if start == nullptr - num_available_bytes_(start ? num_available_bytes : 0), - num_loaded_bytes_(0) { - } - - // Copies length bytes of data to address target. Advances current position - // and returns true on success and false otherwise. - bool Read(void *target, uint64 length) { - if (length > num_available_bytes_) { - TC_LOG(WARNING) << "Not enough bytes: available " << num_available_bytes_ - << " < required " << length; - return false; - } - memcpy(target, current_, length); - Advance(length); - return true; - } - - // Reads the string encoded at the current position. The bytes starting at - // current position should contain (1) little-endian uint32 size (in bytes) of - // the actual string and next (2) the actual bytes of the string. Advances - // the current position and returns true if successful, false otherwise. - // - // On success, sets *view to be a view of the relevant bytes: view.data() - // points to the beginning of the string bytes, and view.size() is the number - // of such bytes. 
- bool ReadString(StringPiece *view) { - uint32 size; - if (!Read(&size, sizeof(size))) { - TC_LOG(ERROR) << "Unable to read std::string size"; - return false; - } - size = LittleEndian::ToHost32(size); - if (size > num_available_bytes_) { - TC_LOG(WARNING) << "Not enough bytes: " << num_available_bytes_ - << " available < " << size << " required "; - return false; - } - *view = StringPiece(current_, size); - Advance(size); - return true; - } - - // Like ReadString(StringPiece *) but reads directly into a C++ string, - // instead of a StringPiece (StringPiece-like object). - bool ReadString(std::string *target) { - StringPiece view; - if (!ReadString(&view)) { - return false; - } - *target = view.ToString(); - return true; - } - - // Returns current position. - const char *GetCurrent() const { return current_; } - - // Returns remaining number of available bytes. - uint64 GetNumAvailableBytes() const { return num_available_bytes_; } - - // Returns number of bytes read ("loaded") so far. - uint64 GetNumLoadedBytes() const { return num_loaded_bytes_; } - - // Advance the current read position by indicated number of bytes. Returns - // true on success, false otherwise (e.g., if there are not enough available - // bytes to advance num_bytes). - bool Advance(uint64 num_bytes) { - if (num_bytes > num_available_bytes_) { - return false; - } - - // Next line never results in an underflow of the unsigned - // num_available_bytes_, due to the previous if. - num_available_bytes_ -= num_bytes; - current_ += num_bytes; - num_loaded_bytes_ += num_bytes; - return true; - } - - // Advance current position to nearest multiple of alignment. Returns false - // if not enough bytes available to do that, true (success) otherwise. - bool SkipToAlign(int alignment) { - int num_extra_bytes = num_loaded_bytes_ % alignment; - if (num_extra_bytes == 0) { - return true; - } - return Advance(alignment - num_extra_bytes); - } - - private: - // Current position in the in-memory data. 
Next call to Read() will read from - // this address. - const char *current_; - - // Number of available bytes we can still read. - uint64 num_available_bytes_; - - // Number of bytes read ("loaded") so far. - uint64 num_loaded_bytes_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_LOW_LEVEL_MEMORY_READER_H_ diff --git a/common/memory_image/memory-image-common.cc b/common/memory_image/memory-image-common.cc deleted file mode 100644 index 6debf1d..0000000 --- a/common/memory_image/memory-image-common.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/memory_image/memory-image-common.h" - -namespace libtextclassifier { -namespace nlp_core { - -// IMPORTANT: this signature should never change. If you change the protocol, -// update kCurrentVersion, *not* this signature. 
-const char MemoryImageConstants::kSignature[] = "Memory image $5%1#o3-1x32"; - -const int MemoryImageConstants::kCurrentVersion = 1; - -const int MemoryImageConstants::kDefaultAlignment = 16; - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/memory_image/memory-image-common.h b/common/memory_image/memory-image-common.h deleted file mode 100644 index 3a46f49..0000000 --- a/common/memory_image/memory-image-common.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Common utils for memory images. 
- -#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_COMMON_H_ -#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_COMMON_H_ - -#include <stddef.h> - -#include <string> - -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { - -class MemoryImageConstants { - public: - static const char kSignature[]; - static const int kCurrentVersion; - static const int kDefaultAlignment; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_COMMON_H_ diff --git a/common/memory_image/memory-image-reader.cc b/common/memory_image/memory-image-reader.cc deleted file mode 100644 index 7e717d5..0000000 --- a/common/memory_image/memory-image-reader.cc +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/memory_image/memory-image-reader.h" - -#include <string> - -#include "common/memory_image/low-level-memory-reader.h" -#include "common/memory_image/memory-image-common.h" -#include "common/memory_image/memory-image.pb.h" -#include "util/base/endian.h" -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -namespace { - -// Checks that the memory area read by mem_reader starts with the expected -// signature. Advances mem_reader past the signature and returns success -// status. 
-bool ReadAndCheckSignature(LowLevelMemReader *mem_reader) { - const std::string expected_signature = MemoryImageConstants::kSignature; - const int signature_size = expected_signature.size(); - if (mem_reader->GetNumAvailableBytes() < signature_size) { - TC_LOG(ERROR) << "Not enough bytes to check signature"; - return false; - } - const std::string actual_signature(mem_reader->GetCurrent(), signature_size); - if (!mem_reader->Advance(signature_size)) { - TC_LOG(ERROR) << "Failed to advance past signature"; - return false; - } - if (actual_signature != expected_signature) { - TC_LOG(ERROR) << "Different signature: actual \"" << actual_signature - << "\" != expected \"" << expected_signature << "\""; - return false; - } - return true; -} - -// Parses MemoryImageHeader from mem_reader. Advances mem_reader past it. -// Returns success status. -bool ParseMemoryImageHeader( - LowLevelMemReader *mem_reader, MemoryImageHeader *header) { - std::string header_proto_str; - if (!mem_reader->ReadString(&header_proto_str)) { - TC_LOG(ERROR) << "Unable to read header_proto_str"; - return false; - } - if (!header->ParseFromString(header_proto_str)) { - TC_LOG(ERROR) << "Unable to parse MemoryImageHeader"; - return false; - } - return true; -} - -} // namespace - -bool GeneralMemoryImageReader::ReadMemoryImage() { - LowLevelMemReader mem_reader(start_, num_bytes_); - - // Read and check signature. - if (!ReadAndCheckSignature(&mem_reader)) { - return false; - } - - // Parse MemoryImageHeader header_. - if (!ParseMemoryImageHeader(&mem_reader, &header_)) { - return false; - } - - // Check endianness. - if (header_.is_little_endian() != LittleEndian::IsLittleEndian()) { - // TODO(salcianu): implement conversion: it will take time, but it's better - // than crashing. Not very urgent: [almost] all current Android phones are - // little-endian. - TC_LOG(ERROR) << "Memory image is " - << (header_.is_little_endian() ? "little" : "big") - << " endian. 
" - << "Local system is different and we don't currently support " - << "conversion between the two."; - return false; - } - - // Read binary serialization of trimmed original proto. - if (!mem_reader.ReadString(&trimmed_proto_serialization_)) { - TC_LOG(ERROR) << "Unable to read trimmed proto binary serialization"; - return false; - } - - // Fill vector of pointers to beginning of each data blob. - for (int i = 0; i < header_.blob_info_size(); ++i) { - const MemoryImageDataBlobInfo &blob_info = header_.blob_info(i); - if (!mem_reader.SkipToAlign(header_.alignment())) { - TC_LOG(ERROR) << "Unable to align for blob #i" << i; - return false; - } - data_blob_views_.emplace_back( - mem_reader.GetCurrent(), - blob_info.num_bytes()); - if (!mem_reader.Advance(blob_info.num_bytes())) { - TC_LOG(ERROR) << "Not enough bytes for blob #i" << i; - return false; - } - } - - return true; -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/memory_image/memory-image-reader.h b/common/memory_image/memory-image-reader.h deleted file mode 100644 index c5954fd..0000000 --- a/common/memory_image/memory-image-reader.h +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// MemoryImageReader, class for reading a memory image. 
- -#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_READER_H_ -#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_READER_H_ - -#include <string> -#include <vector> - -#include "common/memory_image/memory-image.pb.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/base/macros.h" -#include "util/strings/stringpiece.h" - -namespace libtextclassifier { -namespace nlp_core { - -// General, non-templatized class, to reduce code duplication. -// -// Given a memory area (pointer to start + size in bytes) parses a memory image -// from there into (1) MemoryImageHeader proto (it includes the serialized form -// of the trimmed down original proto) and (2) a list of void* pointers to the -// beginning of all data blobs. -// -// In case of parsing errors, we prefer to log the error and set the -// success_status() to false, instead of CHECK-failing . This way, the client -// has the option of performing error recovery or crashing. Some mobile apps -// don't like crashing (a restart is very slow) so, if possible, we try to avoid -// that. -class GeneralMemoryImageReader { - public: - // Constructs this object. See class-level comments. Note: the memory area - // pointed to by start should not be deallocated while this object is used: - // this object does not copy it; instead, it keeps pointers inside that memory - // area. - GeneralMemoryImageReader(const void *start, uint64 num_bytes) - : start_(start), num_bytes_(num_bytes) { - success_ = ReadMemoryImage(); - } - - virtual ~GeneralMemoryImageReader() {} - - // Returns true if reading the memory image has been successful. If this - // returns false, then none of the other accessors should be used. - bool success_status() const { return success_; } - - // Returns number of data blobs from the memory image. - int num_data_blobs() const { - return data_blob_views_.size(); - } - - // Returns pointer to the beginning of the data blob #i. 
- StringPiece data_blob_view(int i) const { - if ((i < 0) || (i >= num_data_blobs())) { - TC_LOG(ERROR) << "Blob index " << i << " outside range [0, " - << num_data_blobs() << "); will return empty data chunk"; - return StringPiece(); - } - return data_blob_views_[i]; - } - - // Returns std::string with binary serialization of the original proto, but - // trimmed of the large fields (those were placed in the data blobs). - std::string trimmed_proto_str() const { - return trimmed_proto_serialization_.ToString(); - } - - // Same as above but returns the trimmed proto as a string piece pointing to - // the image. - StringPiece trimmed_proto_view() const { - return trimmed_proto_serialization_; - } - - const MemoryImageHeader &header() { return header_; } - - protected: - void set_as_failed() { - success_ = false; - } - - private: - bool ReadMemoryImage(); - - // Pointer to beginning of memory image. Not owned. - const void *const start_; - - // Number of bytes in the memory image. This class will not read more bytes. - const uint64 num_bytes_; - - // MemoryImageHeader parsed from the memory image. - MemoryImageHeader header_; - - // Binary serialization of the trimmed version of the original proto. - // Represented as a StringPiece backed up by the underlying memory image - // bytes. - StringPiece trimmed_proto_serialization_; - - // List of StringPiece objects for all data blobs from the memory image (in - // order). - std::vector<StringPiece> data_blob_views_; - - // Memory reading success status. - bool success_; - - TC_DISALLOW_COPY_AND_ASSIGN(GeneralMemoryImageReader); -}; - -// Like GeneralMemoryImageReader, but has knowledge about the type of the -// original proto. As such, it can parse it (well, the trimmed version) and -// offer access to it. -// -// Template parameter T should be the type of the original proto. 
-template<class T> -class MemoryImageReader : public GeneralMemoryImageReader { - public: - MemoryImageReader(const void *start, uint64 num_bytes) - : GeneralMemoryImageReader(start, num_bytes) { - if (!trimmed_proto_.ParseFromString(trimmed_proto_str())) { - TC_LOG(INFO) << "Unable to parse the trimmed proto"; - set_as_failed(); - } - } - - // Returns const reference to the trimmed version of the original proto. - // Useful for retrieving the many small fields that are not converted into - // data blobs. - const T &trimmed_proto() const { return trimmed_proto_; } - - private: - T trimmed_proto_; - - TC_DISALLOW_COPY_AND_ASSIGN(MemoryImageReader); -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_READER_H_ diff --git a/common/memory_image/memory-image.proto b/common/memory_image/memory-image.proto deleted file mode 100644 index f6b624c..0000000 --- a/common/memory_image/memory-image.proto +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Protos for "memory images". - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -package libtextclassifier.nlp_core; - -message MemoryImageDataBlobInfo { - // Size (in bytes) of this data blob. - optional uint64 num_bytes = 1; - - // Indicates whether this data blob corresponds to an array. 
- optional bool is_array = 2 [default = true]; - - // Size (in bytes) of each array element. Useful for little <-> big endian - // conversions. -1 means unknown: no endianness conversion in that case. - optional int32 element_size = 3 [default = -1]; -} - -message MemoryImageHeader { - // Version of the algorithm used to produce the memory image. We should - // increase the value used here every time we perform an incompatible change. - // Algorithm version v should handle only memory images of the same version, - // and crash otherwise. - optional int32 version = 1 [default = -1]; - - // True if the info stored in the data blobs uses the little endian - // convention. Almost all machines today are little-endian but we want to be - // able to crash with an informative message or perform a (costly) conversion - // in the rare cases when that's not true. - optional bool is_little_endian = 2 [default = true]; - - // Alignment (in bytes) for all data blobs. E.g., if this field is 16, then - // each data blob starts at an offset that's a multiple of 16, where the - // offset is measured from the beginning of the memory image. On the client - // side, allocating the entire memory image at an aligned address (by same - // alignment) makes sure all data blobs are properly aligned. - // - // NOTE: I (salcianu) explored the idea of a different alignment for each data - // blob: e.g., float[] should be fine with 4-byte alignment (alignment = 4) - // but char[] are fine with no alignment (alignment = 1). As we expect only a - // few (but large) data blobs, the space benefit is not worth the extra code - // complexity. - optional int32 alignment = 3 [default = 8]; - - // One MemoryImageDataBlobInfo for each data blob, in order. There is one - // data blob for each large field we handle specially. 
- repeated MemoryImageDataBlobInfo blob_info = 4; -} diff --git a/common/mock_functions.h b/common/mock_functions.h deleted file mode 100644 index b5bcb07..0000000 --- a/common/mock_functions.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_ -#define LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_ - -#include <math.h> - -#include "common/registry.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace functions { - -// Abstract double -> double function. -class Function : public RegisterableClass<Function> { - public: - virtual ~Function() {} - virtual double Evaluate(double x) = 0; -}; - -class Cos : public Function { - public: - double Evaluate(double x) override { return cos(x); } - TC_DEFINE_REGISTRATION_METHOD("cos", Cos); -}; - -class Exp : public Function { - public: - double Evaluate(double x) override { return exp(x); } - TC_DEFINE_REGISTRATION_METHOD("exp", Exp); -}; - -// Abstract int -> int function. 
-class IntFunction : public RegisterableClass<IntFunction> { - public: - virtual ~IntFunction() {} - virtual int Evaluate(int k) = 0; -}; - -class Inc : public IntFunction { - public: - int Evaluate(int k) override { return k + 1; } - TC_DEFINE_REGISTRATION_METHOD("inc", Inc); -}; - -class Dec : public IntFunction { - public: - int Evaluate(int k) override { return k + 1; } - TC_DEFINE_REGISTRATION_METHOD("dec", Dec); -}; - -} // namespace functions - -// Should be inside namespace libtextclassifier::nlp_core. -TC_DECLARE_CLASS_REGISTRY_NAME(functions::Function); -TC_DECLARE_CLASS_REGISTRY_NAME(functions::IntFunction); - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_ diff --git a/common/registry.h b/common/registry.h deleted file mode 100644 index d958225..0000000 --- a/common/registry.h +++ /dev/null @@ -1,281 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Mechanism to instantiate classes by name. -// -// This mechanism is useful if the concrete classes to be instantiated are not -// statically known (e.g., if their names are read from a dynamically-provided -// config). -// -// In that case, the first step is to define the API implemented by the -// instantiated classes. E.g., -// -// // In a header file function.h: -// -// // Abstract function that takes a double and returns a double. 
-// class Function : public RegisterableClass<Function> { -// public: -// virtual ~Function() {} -// virtual double Evaluate(double x) = 0; -// }; -// -// // Should be inside namespace libtextclassifier::nlp_core. -// TC_DECLARE_CLASS_REGISTRY_NAME(Function); -// -// Notice the inheritance from RegisterableClass<Function>. RegisterableClass -// is defined by this file (registry.h). Under the hood, this inheritanace -// defines a "registry" that maps names (zero-terminated arrays of chars) to -// factory methods that create Functions. You should give a human-readable name -// to this registry. To do that, use the following macro in a .cc file (it has -// to be a .cc file, as it defines some static data): -// -// // Inside function.cc -// // Should be inside namespace libtextclassifier::nlp_core. -// TC_DEFINE_CLASS_REGISTRY_NAME("function", Function); -// -// Now, let's define a few concrete Functions: e.g., -// -// class Cos : public Function { -// public: -// double Evaluate(double x) override { return cos(x); } -// TC_DEFINE_REGISTRATION_METHOD("cos", Cos); -// }; -// -// class Exp : public Function { -// public: -// double Evaluate(double x) override { return exp(x); } -// TC_DEFINE_REGISTRATION_METHOD("sin", Sin); -// }; -// -// Each concrete Function implementation should have (in the public section) the -// macro -// -// TC_DEFINE_REGISTRATION_METHOD("name", implementation_class); -// -// This defines a RegisterClass static method that, when invoked, associates -// "name" with a factory method that creates instances of implementation_class. -// -// Before instantiating Functions by name, we need to tell our system which -// Functions we may be interested in. This is done by calling the -// Foo::RegisterClass() for each relevant Foo implementation of Function. It is -// ok to call Foo::RegisterClass() multiple times (even in parallel): only the -// first call will perform something, the others will return immediately. 
-// -// Cos::RegisterClass(); -// Exp::RegisterClass(); -// -// Now, let's instantiate a Function based on its name. This get a lot more -// interesting if the Function name is not statically known (i.e., -// read from an input proto: -// -// std::unique_ptr<Function> f(Function::Create("cos")); -// double result = f->Evaluate(arg); -// -// NOTE: the same binary can use this mechanism for different APIs. E.g., one -// can also have (in the binary with Function, Sin, Cos, etc): -// -// class IntFunction : public RegisterableClass<IntFunction> { -// public: -// virtual ~IntFunction() {} -// virtual int Evaluate(int k) = 0; -// }; -// -// TC_DECLARE_CLASS_REGISTRY_NAME(IntFunction); -// -// TC_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction); -// -// class Inc : public IntFunction { -// public: -// int Evaluate(int k) override { return k + 1; } -// TC_DEFINE_REGISTRATION_METHOD("inc", Inc); -// }; -// -// RegisterableClass<Function> and RegisterableClass<IntFunction> define their -// own registries: each maps string names to implementation of the corresponding -// API. - -#ifndef LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_ -#define LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_ - -#include <stdlib.h> -#include <string.h> - -#include <string> - -#include "util/base/logging.h" - -namespace libtextclassifier { -namespace nlp_core { - -namespace internal { -// Registry that associates keys (zero-terminated array of chars) with values. -// Values are pointers to type T (the template parameter). This is used to -// store the association between component names and factory methods that -// produce those components; the error messages are focused on that case. -// -// Internally, this registry uses a linked list of (key, value) pairs. We do -// not use an STL map, list, etc because we aim for small code size. 
-template <class T> -class ComponentRegistry { - public: - explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {} - - // Adds a the (key, value) pair to this registry (if the key does not already - // exists in this registry) and returns true. If the registry already has a - // mapping for key, returns false and does not modify the registry. NOTE: the - // error (false) case happens even if the existing value for key is equal with - // the new one. - // - // This method does not take ownership of key, nor of value. - bool Add(const char *key, T *value) { - const Cell *old_cell = FindCell(key); - if (old_cell != nullptr) { - TC_LOG(ERROR) << "Duplicate component: " << key; - return false; - } - Cell *new_cell = new Cell(key, value, head_); - head_ = new_cell; - return true; - } - - // Returns the value attached to a key in this registry. Returns nullptr on - // error (e.g., unknown key). - T *Lookup(const char *key) const { - const Cell *cell = FindCell(key); - if (cell == nullptr) { - TC_LOG(ERROR) << "Unknown " << name() << " component: " << key; - } - return (cell == nullptr) ? nullptr : cell->value(); - } - - T *Lookup(const std::string &key) const { return Lookup(key.c_str()); } - - // Returns name of this ComponentRegistry. - const char *name() const { return name_; } - - private: - // Cell for the singly-linked list underlying this ComponentRegistry. Each - // cell contains a key, the value for that key, as well as a pointer to the - // next Cell from the list. - class Cell { - public: - // Constructs a new Cell. - Cell(const char *key, T *value, Cell *next) - : key_(key), value_(value), next_(next) {} - - const char *key() const { return key_; } - T *value() const { return value_; } - Cell *next() const { return next_; } - - private: - const char *const key_; - T *const value_; - Cell *const next_; - }; - - // Finds Cell for indicated key in the singly-linked list pointed to by head_. 
- // Returns pointer to that first Cell with that key, or nullptr if no such - // Cell (i.e., unknown key). - // - // Caller does NOT own the returned pointer. - const Cell *FindCell(const char *key) const { - Cell *c = head_; - while (c != nullptr && strcmp(key, c->key()) != 0) { - c = c->next(); - } - return c; - } - - // Human-readable description for this ComponentRegistry. For debug purposes. - const char *const name_; - - // Pointer to the first Cell from the underlying list of (key, value) pairs. - Cell *head_; -}; -} // namespace internal - -// Base class for registerable classes. -template <class T> -class RegisterableClass { - public: - // Factory function type. - typedef T *(Factory)(); - - // Registry type. - typedef internal::ComponentRegistry<Factory> Registry; - - // Creates a new instance of T. Returns pointer to new instance or nullptr in - // case of errors (e.g., unknown component). - // - // Passes ownership of the returned pointer to the caller. - static T *Create(const std::string &name) { // NOLINT - auto *factory = registry()->Lookup(name); - if (factory == nullptr) { - TC_LOG(ERROR) << "Unknown RegisterableClass " << name; - return nullptr; - } - return factory(); - } - - // Returns registry for class. - static Registry *registry() { - static Registry *registry_for_type_t = new Registry(kRegistryName); - return registry_for_type_t; - } - - protected: - // Factory method for subclass ComponentClass. Used internally by the static - // method RegisterClass() defined by TC_DEFINE_REGISTRATION_METHOD. - template <class ComponentClass> - static T *_internal_component_factory() { - return new ComponentClass(); - } - - private: - // Human-readable name for the registry for this class. - static const char kRegistryName[]; -}; - -// Defines the static method component_class::RegisterClass() that should be -// called before trying to instantiate component_class by name. 
Should be used -// inside the public section of the declaration of component_class. See -// comments at the top-level of this file. -#define TC_DEFINE_REGISTRATION_METHOD(component_name, component_class) \ - static void RegisterClass() { \ - static bool once = registry()->Add( \ - component_name, &_internal_component_factory<component_class>); \ - if (!once) { \ - TC_LOG(ERROR) << "Problem registering " << component_name; \ - } \ - TC_DCHECK(once); \ - } - -// Defines the human-readable name of the registry associated with base_class. -#define TC_DECLARE_CLASS_REGISTRY_NAME(base_class) \ - template <> \ - const char ::libtextclassifier::nlp_core::RegisterableClass< \ - base_class>::kRegistryName[] - -// Defines the human-readable name of the registry associated with base_class. -#define TC_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \ - template <> \ - const char ::libtextclassifier::nlp_core::RegisterableClass< \ - base_class>::kRegistryName[] = registry_name - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_ diff --git a/common/registry_test.cc b/common/registry_test.cc deleted file mode 100644 index d5d7006..0000000 --- a/common/registry_test.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <memory> - -#include "common/mock_functions.h" -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace functions { - -TEST(RegistryTest, InstantiateFunctionsByName) { - // First, we need to register the functions we are interested in: - Exp::RegisterClass(); - Inc::RegisterClass(); - Cos::RegisterClass(); - - // RegisterClass methods can be called in any order, even multiple times :) - Cos::RegisterClass(); - Inc::RegisterClass(); - Inc::RegisterClass(); - Cos::RegisterClass(); - Inc::RegisterClass(); - - // NOTE: we intentionally do not register Dec. Attempts to create an instance - // of that function by name should fail. - - // Instantiate a few functions and check that the created functions produce - // the expected results for a few sample values. - std::unique_ptr<Function> f1(Function::Create("cos")); - ASSERT_NE(f1, nullptr); - std::unique_ptr<Function> f2(Function::Create("exp")); - ASSERT_NE(f2, nullptr); - EXPECT_NEAR(f1->Evaluate(-3), -0.9899, 0.0001); - EXPECT_NEAR(f2->Evaluate(2.3), 9.9741, 0.0001); - - std::unique_ptr<IntFunction> f3(IntFunction::Create("inc")); - ASSERT_NE(f3, nullptr); - EXPECT_EQ(f3->Evaluate(7), 8); - - // Instantiating unknown functions should return nullptr, but not crash - // anything. - EXPECT_EQ(Function::Create("mambo"), nullptr); - - // Functions that are defined in the code, but are not registered are unknown. - EXPECT_EQ(IntFunction::Create("dec"), nullptr); - - // Function and IntFunction use different registries. 
- EXPECT_EQ(IntFunction::Create("exp"), nullptr); -} - -} // namespace functions -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/simple-adder.h b/common/simple-adder.h deleted file mode 100644 index c16cc8a..0000000 --- a/common/simple-adder.h +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_SIMPLE_ADDER_H_ -#define LIBTEXTCLASSIFIER_COMMON_SIMPLE_ADDER_H_ - -#include "util/base/integral_types.h" -#include "util/base/port.h" - -namespace libtextclassifier { -namespace nlp_core { - -// Implements add and scaleadd in the most straight-forward way, and it doesn't -// have any additional requirement on the alignment and array size. -class SimpleAdder { - public: - TC_ATTRIBUTE_ALWAYS_INLINE SimpleAdder(float *dest, int num_floats) - : dest_(dest), num_floats_(num_floats) {} - - TC_ATTRIBUTE_ALWAYS_INLINE void LazyAdd(const float *source) const { - AddImpl(source, num_floats_, dest_); - } - - TC_ATTRIBUTE_ALWAYS_INLINE void LazyScaleAdd(const float *source, - const float scale) const { - ScaleAddImpl(source, num_floats_, scale, dest_); - } - - // Simple fast while loop to implement dest += source. 
- TC_ATTRIBUTE_ALWAYS_INLINE static void AddImpl(const float *__restrict source, - uint32 size, - float *__restrict dest) { - for (uint32 i = 0; i < size; ++i) { - dest[i] += source[i]; - } - } - - // Simple fast while loop to implement dest += scale * source. - TC_ATTRIBUTE_ALWAYS_INLINE static void ScaleAddImpl( - const float *__restrict source, uint32 size, const float scale, - float *__restrict dest) { - for (uint32 i = 0; i < size; ++i) { - dest[i] += source[i] * scale; - } - } - - private: - float *dest_; - int num_floats_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_SIMPLE_ADDER_H_ diff --git a/common/task-context.cc b/common/task-context.cc deleted file mode 100644 index e4c1090..0000000 --- a/common/task-context.cc +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/task-context.h" - -#include <stdlib.h> - -#include <string> - -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/strings/numbers.h" - -namespace libtextclassifier { -namespace nlp_core { - -namespace { -int32 ParseInt32WithDefault(const std::string &s, int32 defval) { - int32 value = defval; - return ParseInt32(s.c_str(), &value) ? value : defval; -} - -int64 ParseInt64WithDefault(const std::string &s, int64 defval) { - int64 value = defval; - return ParseInt64(s.c_str(), &value) ? 
value : defval; -} - -double ParseDoubleWithDefault(const std::string &s, double defval) { - double value = defval; - return ParseDouble(s.c_str(), &value) ? value : defval; -} -} // namespace - -TaskInput *TaskContext::GetInput(const std::string &name) { - // Return existing input if it exists. - for (int i = 0; i < spec_.input_size(); ++i) { - if (spec_.input(i).name() == name) return spec_.mutable_input(i); - } - - // Create new input. - TaskInput *input = spec_.add_input(); - input->set_name(name); - return input; -} - -TaskInput *TaskContext::GetInput(const std::string &name, - const std::string &file_format, - const std::string &record_format) { - TaskInput *input = GetInput(name); - if (!file_format.empty()) { - bool found = false; - for (int i = 0; i < input->file_format_size(); ++i) { - if (input->file_format(i) == file_format) found = true; - } - if (!found) input->add_file_format(file_format); - } - if (!record_format.empty()) { - bool found = false; - for (int i = 0; i < input->record_format_size(); ++i) { - if (input->record_format(i) == record_format) found = true; - } - if (!found) input->add_record_format(record_format); - } - return input; -} - -void TaskContext::SetParameter(const std::string &name, - const std::string &value) { - TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")"; - - // If the parameter already exists update the value. - for (int i = 0; i < spec_.parameter_size(); ++i) { - if (spec_.parameter(i).name() == name) { - spec_.mutable_parameter(i)->set_value(value); - return; - } - } - - // Add new parameter. - TaskSpec::Parameter *param = spec_.add_parameter(); - param->set_name(name); - param->set_value(value); -} - -std::string TaskContext::GetParameter(const std::string &name) const { - // First try to find parameter in task specification. 
- for (int i = 0; i < spec_.parameter_size(); ++i) { - if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); - } - - // Parameter not found, return empty std::string. - return ""; -} - -int TaskContext::GetIntParameter(const std::string &name) const { - std::string value = GetParameter(name); - return ParseInt32WithDefault(value, 0); -} - -int64 TaskContext::GetInt64Parameter(const std::string &name) const { - std::string value = GetParameter(name); - return ParseInt64WithDefault(value, 0); -} - -bool TaskContext::GetBoolParameter(const std::string &name) const { - std::string value = GetParameter(name); - return value == "true"; -} - -double TaskContext::GetFloatParameter(const std::string &name) const { - std::string value = GetParameter(name); - return ParseDoubleWithDefault(value, 0.0); -} - -std::string TaskContext::Get(const std::string &name, - const char *defval) const { - // First try to find parameter in task specification. - for (int i = 0; i < spec_.parameter_size(); ++i) { - if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); - } - - // Parameter not found, return default value. - return defval; -} - -std::string TaskContext::Get(const std::string &name, - const std::string &defval) const { - return Get(name, defval.c_str()); -} - -int TaskContext::Get(const std::string &name, int defval) const { - std::string value = Get(name, ""); - return ParseInt32WithDefault(value, defval); -} - -int64 TaskContext::Get(const std::string &name, int64 defval) const { - std::string value = Get(name, ""); - return ParseInt64WithDefault(value, defval); -} - -double TaskContext::Get(const std::string &name, double defval) const { - std::string value = Get(name, ""); - return ParseDoubleWithDefault(value, defval); -} - -bool TaskContext::Get(const std::string &name, bool defval) const { - std::string value = Get(name, ""); - return value.empty() ? 
defval : value == "true"; -} - -std::string TaskContext::InputFile(const TaskInput &input) { - if (input.part_size() == 0) { - TC_LOG(ERROR) << "No file for TaskInput " << input.name(); - return ""; - } - if (input.part_size() > 1) { - TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name(); - } - return input.part(0).file_pattern(); -} - -bool TaskContext::Supports(const TaskInput &input, - const std::string &file_format, - const std::string &record_format) { - // Check file format. - if (input.file_format_size() > 0) { - bool found = false; - for (int i = 0; i < input.file_format_size(); ++i) { - if (input.file_format(i) == file_format) { - found = true; - break; - } - } - if (!found) return false; - } - - // Check record format. - if (input.record_format_size() > 0) { - bool found = false; - for (int i = 0; i < input.record_format_size(); ++i) { - if (input.record_format(i) == record_format) { - found = true; - break; - } - } - if (!found) return false; - } - - return true; -} - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/task-context.h b/common/task-context.h deleted file mode 100644 index c55ed67..0000000 --- a/common/task-context.h +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_TASK_CONTEXT_H_ -#define LIBTEXTCLASSIFIER_COMMON_TASK_CONTEXT_H_ - -#include <string> -#include <vector> - -#include "common/task-spec.pb.h" -#include "util/base/integral_types.h" - -namespace libtextclassifier { -namespace nlp_core { - -// A task context holds configuration information for a task. It is basically a -// wrapper around a TaskSpec protocol buffer. -class TaskContext { - public: - // Returns the underlying task specification protocol buffer for the context. - const TaskSpec &spec() const { return spec_; } - TaskSpec *mutable_spec() { return &spec_; } - - // Returns a named input descriptor for the task. A new input is created if - // the task context does not already have an input with that name. - TaskInput *GetInput(const std::string &name); - TaskInput *GetInput(const std::string &name, - const std::string &file_format, - const std::string &record_format); - - // Sets task parameter. - void SetParameter(const std::string &name, const std::string &value); - - // Returns task parameter. If the parameter is not in the task configuration - // the (default) value of the corresponding command line flag is returned. - std::string GetParameter(const std::string &name) const; - int GetIntParameter(const std::string &name) const; - int64 GetInt64Parameter(const std::string &name) const; - bool GetBoolParameter(const std::string &name) const; - double GetFloatParameter(const std::string &name) const; - - // Returns task parameter. If the parameter is not in the task configuration - // the default value is returned. 
- std::string Get(const std::string &name, const std::string &defval) const; - std::string Get(const std::string &name, const char *defval) const; - int Get(const std::string &name, int defval) const; - int64 Get(const std::string &name, int64 defval) const; - double Get(const std::string &name, double defval) const; - bool Get(const std::string &name, bool defval) const; - - // Returns input file name for a single-file task input. - // - // Special cases: returns the empty string if the TaskInput does not have any - // input files. Returns the first file if the TaskInput has multiple input - // files. - static std::string InputFile(const TaskInput &input); - - // Returns true if task input supports the file and record format. - static bool Supports(const TaskInput &input, const std::string &file_format, - const std::string &record_format); - - private: - // Underlying task specification protocol buffer. - TaskSpec spec_; - - // Vector of parameters required by this task. These must be specified in the - // task rather than relying on default values. - std::vector<std::string> required_parameters_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_TASK_CONTEXT_H_ diff --git a/common/task-spec.proto b/common/task-spec.proto deleted file mode 100644 index ab986ce..0000000 --- a/common/task-spec.proto +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// LINT: ALLOW_GROUPS -// Protocol buffer specifications for task configuration. - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -package libtextclassifier.nlp_core; - -// Task input descriptor. -message TaskInput { - // Name of input resource. - required string name = 1; - - // File format for resource. - repeated string file_format = 3; - - // Record format for resource. - repeated string record_format = 4; - - // An input can consist of multiple file sets. - repeated group Part = 6 { - // File pattern for file set. - optional string file_pattern = 7; - - // File format for file set. - optional string file_format = 8; - - // Record format for file set. - optional string record_format = 9; - } - - reserved 2, 5; -} - -// A task specification is used for describing executing parameters. -message TaskSpec { - // Task parameters. - repeated group Parameter = 3 { - required string name = 4; - optional string value = 5; - } - - // Task inputs. - repeated TaskInput input = 6; - - reserved 1, 2, 7; -} diff --git a/common/vector-span.h b/common/vector-span.h deleted file mode 100644 index d7fbfe9..0000000 --- a/common/vector-span.h +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_ -#define LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_ - -#include <vector> - -namespace libtextclassifier { - -// StringPiece analogue for std::vector<T>. -template <class T> -class VectorSpan { - public: - VectorSpan() : begin_(), end_() {} - VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit) - : begin_(v.begin()), end_(v.end()) {} - VectorSpan(typename std::vector<T>::const_iterator begin, - typename std::vector<T>::const_iterator end) - : begin_(begin), end_(end) {} - - const T& operator[](typename std::vector<T>::size_type i) const { - return *(begin_ + i); - } - - int size() const { return end_ - begin_; } - typename std::vector<T>::const_iterator begin() const { return begin_; } - typename std::vector<T>::const_iterator end() const { return end_; } - - private: - typename std::vector<T>::const_iterator begin_; - typename std::vector<T>::const_iterator end_; -}; - -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_ diff --git a/common/workspace.cc b/common/workspace.cc deleted file mode 100644 index 770e4be..0000000 --- a/common/workspace.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "common/workspace.h" - -#include <atomic> -#include <string> - -namespace libtextclassifier { -namespace nlp_core { - -// static -int GetFreshTypeId() { - // Static local below is initialized the first time this method is run. - static std::atomic<int> counter(0); - return counter++; -} - -std::string WorkspaceRegistry::DebugString() const { - std::string str; - for (auto &it : workspace_names_) { - const std::string &type_name = workspace_types_.at(it.first); - for (size_t index = 0; index < it.second.size(); ++index) { - const std::string &workspace_name = it.second[index]; - str.append("\n "); - str.append(type_name); - str.append(" :: "); - str.append(workspace_name); - } - } - return str; -} - -VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {} - -VectorIntWorkspace::VectorIntWorkspace(int size, int value) - : elements_(size, value) {} - -VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements) - : elements_(elements) {} - -std::string VectorIntWorkspace::TypeName() { return "Vector"; } - -VectorVectorIntWorkspace::VectorVectorIntWorkspace(int size) - : elements_(size) {} - -std::string VectorVectorIntWorkspace::TypeName() { return "VectorVector"; } - -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/common/workspace.h b/common/workspace.h deleted file mode 100644 index e003bde..0000000 --- a/common/workspace.h +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Notes on thread-safety: All of the classes here are thread-compatible. More -// specifically, the registry machinery is thread-safe, as long as each thread -// performs feature extraction on a different Sentence object. - -#ifndef LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_ -#define LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_ - -#include <stddef.h> -#include <string> -#include <unordered_map> -#include <utility> -#include <vector> - -#include "util/base/logging.h" -#include "util/base/macros.h" - -namespace libtextclassifier { -namespace nlp_core { - -// A base class for shared workspaces. Derived classes implement a static member -// function TypeName() which returns a human readable std::string name for the -// class. -class Workspace { - public: - // Polymorphic destructor. - virtual ~Workspace() {} - - protected: - // Create an empty workspace. - Workspace() {} - - private: - TC_DISALLOW_COPY_AND_ASSIGN(Workspace); -}; - -// Returns a new, strictly increasing int every time it is invoked. -int GetFreshTypeId(); - -// Struct to simulate typeid, but without RTTI. -template <typename T> -struct TypeId { - static int type_id; -}; - -template <typename T> -int TypeId<T>::type_id = GetFreshTypeId(); - -// A registry that keeps track of workspaces. -class WorkspaceRegistry { - public: - // Create an empty registry. - WorkspaceRegistry() {} - - // Returns the index of a named workspace, adding it to the registry first - // if necessary. 
- template <class W> - int Request(const std::string &name) { - const int id = TypeId<W>::type_id; - max_workspace_id_ = std::max(id, max_workspace_id_); - workspace_types_[id] = W::TypeName(); - std::vector<std::string> &names = workspace_names_[id]; - for (int i = 0; i < names.size(); ++i) { - if (names[i] == name) return i; - } - names.push_back(name); - return names.size() - 1; - } - - // Returns the maximum workspace id that has been registered. - int MaxId() const { - return max_workspace_id_; - } - - const std::unordered_map<int, std::vector<std::string> > &WorkspaceNames() - const { - return workspace_names_; - } - - // Returns a std::string describing the registered workspaces. - std::string DebugString() const; - - private: - // Workspace type names, indexed as workspace_types_[typeid]. - std::unordered_map<int, std::string> workspace_types_; - - // Workspace names, indexed as workspace_names_[typeid][workspace]. - std::unordered_map<int, std::vector<std::string> > workspace_names_; - - // The maximum workspace id that has been registered. - int max_workspace_id_ = 0; - - TC_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry); -}; - -// A typed collected of workspaces. The workspaces are indexed according to an -// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are -// also immutable. -class WorkspaceSet { - public: - ~WorkspaceSet() { Reset(WorkspaceRegistry()); } - - // Returns true if a workspace has been set. - template <class W> - bool Has(int index) const { - const int id = TypeId<W>::type_id; - TC_DCHECK_GE(id, 0); - TC_DCHECK_LT(id, workspaces_.size()); - TC_DCHECK_GE(index, 0); - TC_DCHECK_LT(index, workspaces_[id].size()); - if (id >= workspaces_.size()) return false; - return workspaces_[id][index] != nullptr; - } - - // Returns an indexed workspace; the workspace must have been set. 
- template <class W> - const W &Get(int index) const { - TC_DCHECK(Has<W>(index)); - const int id = TypeId<W>::type_id; - const Workspace *w = workspaces_[id][index]; - return reinterpret_cast<const W &>(*w); - } - - // Sets an indexed workspace; this takes ownership of the workspace, which - // must have been new-allocated. It is an error to set a workspace twice. - template <class W> - void Set(int index, W *workspace) { - const int id = TypeId<W>::type_id; - TC_DCHECK_GE(id, 0); - TC_DCHECK_LT(id, workspaces_.size()); - TC_DCHECK_GE(index, 0); - TC_DCHECK_LT(index, workspaces_[id].size()); - TC_DCHECK(workspaces_[id][index] == nullptr); - TC_DCHECK(workspace != nullptr); - workspaces_[id][index] = workspace; - } - - void Reset(const WorkspaceRegistry ®istry) { - // Deallocate current workspaces. - for (auto &it : workspaces_) { - for (size_t index = 0; index < it.size(); ++index) { - delete it[index]; - } - } - workspaces_.clear(); - workspaces_.resize(registry.MaxId() + 1, std::vector<Workspace *>()); - for (auto &it : registry.WorkspaceNames()) { - workspaces_[it.first].resize(it.second.size()); - } - } - - private: - // The set of workspaces, indexed as workspaces_[typeid][index]. - std::vector<std::vector<Workspace *> > workspaces_; -}; - -// A workspace that wraps around a single int. -class SingletonIntWorkspace : public Workspace { - public: - // Default-initializes the int value. - SingletonIntWorkspace() {} - - // Initializes the int with the given value. - explicit SingletonIntWorkspace(int value) : value_(value) {} - - // Returns the name of this type of workspace. - static std::string TypeName() { return "SingletonInt"; } - - // Returns the int value. - int get() const { return value_; } - - // Sets the int value. - void set(int value) { value_ = value; } - - private: - // The enclosed int. - int value_ = 0; -}; - -// A workspace that wraps around a vector of int. 
-class VectorIntWorkspace : public Workspace { - public: - // Creates a vector of the given size. - explicit VectorIntWorkspace(int size); - - // Creates a vector initialized with the given array. - explicit VectorIntWorkspace(const std::vector<int> &elements); - - // Creates a vector of the given size, with each element initialized to the - // given value. - VectorIntWorkspace(int size, int value); - - // Returns the name of this type of workspace. - static std::string TypeName(); - - // Returns the i'th element. - int element(int i) const { return elements_[i]; } - - // Sets the i'th element. - void set_element(int i, int value) { elements_[i] = value; } - - private: - // The enclosed vector. - std::vector<int> elements_; -}; - -// A workspace that wraps around a vector of vector of int. -class VectorVectorIntWorkspace : public Workspace { - public: - // Creates a vector of empty vectors of the given size. - explicit VectorVectorIntWorkspace(int size); - - // Returns the name of this type of workspace. - static std::string TypeName(); - - // Returns the i'th vector of elements. - const std::vector<int> &elements(int i) const { return elements_[i]; } - - // Mutable access to the i'th vector of elements. - std::vector<int> *mutable_elements(int i) { return &(elements_[i]); } - - private: - // The enclosed vector of vector of elements. - std::vector<std::vector<int> > elements_; -}; - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_ diff --git a/datetime/extractor.cc b/datetime/extractor.cc new file mode 100644 index 0000000..f4ab8f4 --- /dev/null +++ b/datetime/extractor.cc @@ -0,0 +1,469 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "datetime/extractor.h" + +#include "util/base/logging.h" + +namespace libtextclassifier2 { + +bool DatetimeExtractor::Extract(DateParseData* result, + CodepointSpan* result_span) const { + result->field_set_mask = 0; + *result_span = {kInvalidIndex, kInvalidIndex}; + + if (rule_.regex->groups() == nullptr) { + return false; + } + + for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) { + UnicodeText group_text; + const int group_type = rule_.regex->groups()->Get(group_id); + if (group_type == DatetimeGroupType_GROUP_UNUSED) { + continue; + } + if (!GroupTextFromMatch(group_id, &group_text)) { + TC_LOG(ERROR) << "Couldn't retrieve group."; + return false; + } + // The pattern can have a group defined in a part that was not matched, + // e.g. an optional part. In this case we'll get an empty content here. 
+ if (group_text.empty()) { + continue; + } + switch (group_type) { + case DatetimeGroupType_GROUP_YEAR: { + if (!ParseYear(group_text, &(result->year))) { + TC_LOG(ERROR) << "Couldn't extract YEAR."; + return false; + } + result->field_set_mask |= DateParseData::YEAR_FIELD; + break; + } + case DatetimeGroupType_GROUP_MONTH: { + if (!ParseMonth(group_text, &(result->month))) { + TC_LOG(ERROR) << "Couldn't extract MONTH."; + return false; + } + result->field_set_mask |= DateParseData::MONTH_FIELD; + break; + } + case DatetimeGroupType_GROUP_DAY: { + if (!ParseDigits(group_text, &(result->day_of_month))) { + TC_LOG(ERROR) << "Couldn't extract DAY."; + return false; + } + result->field_set_mask |= DateParseData::DAY_FIELD; + break; + } + case DatetimeGroupType_GROUP_HOUR: { + if (!ParseDigits(group_text, &(result->hour))) { + TC_LOG(ERROR) << "Couldn't extract HOUR."; + return false; + } + result->field_set_mask |= DateParseData::HOUR_FIELD; + break; + } + case DatetimeGroupType_GROUP_MINUTE: { + if (!ParseDigits(group_text, &(result->minute))) { + TC_LOG(ERROR) << "Couldn't extract MINUTE."; + return false; + } + result->field_set_mask |= DateParseData::MINUTE_FIELD; + break; + } + case DatetimeGroupType_GROUP_SECOND: { + if (!ParseDigits(group_text, &(result->second))) { + TC_LOG(ERROR) << "Couldn't extract SECOND."; + return false; + } + result->field_set_mask |= DateParseData::SECOND_FIELD; + break; + } + case DatetimeGroupType_GROUP_AMPM: { + if (!ParseAMPM(group_text, &(result->ampm))) { + TC_LOG(ERROR) << "Couldn't extract AMPM."; + return false; + } + result->field_set_mask |= DateParseData::AMPM_FIELD; + break; + } + case DatetimeGroupType_GROUP_RELATIONDISTANCE: { + if (!ParseRelationDistance(group_text, &(result->relation_distance))) { + TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD."; + return false; + } + result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD; + break; + } + case DatetimeGroupType_GROUP_RELATION: { + if 
(!ParseRelation(group_text, &(result->relation))) { + TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD."; + return false; + } + result->field_set_mask |= DateParseData::RELATION_FIELD; + break; + } + case DatetimeGroupType_GROUP_RELATIONTYPE: { + if (!ParseRelationType(group_text, &(result->relation_type))) { + TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD."; + return false; + } + result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD; + break; + } + case DatetimeGroupType_GROUP_DUMMY1: + case DatetimeGroupType_GROUP_DUMMY2: + break; + default: + TC_LOG(INFO) << "Unknown group type."; + continue; + } + if (!UpdateMatchSpan(group_id, result_span)) { + TC_LOG(ERROR) << "Couldn't update span."; + return false; + } + } + + if (result_span->first == kInvalidIndex || + result_span->second == kInvalidIndex) { + *result_span = {kInvalidIndex, kInvalidIndex}; + } + + return true; +} + +bool DatetimeExtractor::RuleIdForType(DatetimeExtractorType type, + int* rule_id) const { + auto type_it = type_and_locale_to_rule_.find(type); + if (type_it == type_and_locale_to_rule_.end()) { + return false; + } + + auto locale_it = type_it->second.find(locale_id_); + if (locale_it == type_it->second.end()) { + return false; + } + *rule_id = locale_it->second; + return true; +} + +bool DatetimeExtractor::ExtractType(const UnicodeText& input, + DatetimeExtractorType extractor_type, + UnicodeText* match_result) const { + int rule_id; + if (!RuleIdForType(extractor_type, &rule_id)) { + return false; + } + + std::unique_ptr<UniLib::RegexMatcher> matcher = + rules_[rule_id]->Matcher(input); + if (!matcher) { + return false; + } + + int status; + if (!matcher->Find(&status)) { + return false; + } + + if (match_result != nullptr) { + *match_result = matcher->Group(&status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + } + return true; +} + +bool DatetimeExtractor::GroupTextFromMatch(int group_id, + UnicodeText* result) const { + int status; + *result = 
matcher_.Group(group_id, &status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + return true; +} + +bool DatetimeExtractor::UpdateMatchSpan(int group_id, + CodepointSpan* span) const { + int status; + const int match_start = matcher_.Start(group_id, &status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + const int match_end = matcher_.End(group_id, &status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + if (span->first == kInvalidIndex || span->first > match_start) { + span->first = match_start; + } + if (span->second == kInvalidIndex || span->second < match_end) { + span->second = match_end; + } + + return true; +} + +template <typename T> +bool DatetimeExtractor::MapInput( + const UnicodeText& input, + const std::vector<std::pair<DatetimeExtractorType, T>>& mapping, + T* result) const { + for (const auto& type_value_pair : mapping) { + if (ExtractType(input, type_value_pair.first)) { + *result = type_value_pair.second; + return true; + } + } + return false; +} + +bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input, + int* parsed_number) const { + std::vector<std::pair<int, int>> found_numbers; + for (const auto& type_value_pair : + std::vector<std::pair<DatetimeExtractorType, int>>{ + {DatetimeExtractorType_ZERO, 0}, + {DatetimeExtractorType_ONE, 1}, + {DatetimeExtractorType_TWO, 2}, + {DatetimeExtractorType_THREE, 3}, + {DatetimeExtractorType_FOUR, 4}, + {DatetimeExtractorType_FIVE, 5}, + {DatetimeExtractorType_SIX, 6}, + {DatetimeExtractorType_SEVEN, 7}, + {DatetimeExtractorType_EIGHT, 8}, + {DatetimeExtractorType_NINE, 9}, + {DatetimeExtractorType_TEN, 10}, + {DatetimeExtractorType_ELEVEN, 11}, + {DatetimeExtractorType_TWELVE, 12}, + {DatetimeExtractorType_THIRTEEN, 13}, + {DatetimeExtractorType_FOURTEEN, 14}, + {DatetimeExtractorType_FIFTEEN, 15}, + {DatetimeExtractorType_SIXTEEN, 16}, + {DatetimeExtractorType_SEVENTEEN, 17}, + {DatetimeExtractorType_EIGHTEEN, 18}, + 
{DatetimeExtractorType_NINETEEN, 19}, + {DatetimeExtractorType_TWENTY, 20}, + {DatetimeExtractorType_THIRTY, 30}, + {DatetimeExtractorType_FORTY, 40}, + {DatetimeExtractorType_FIFTY, 50}, + {DatetimeExtractorType_SIXTY, 60}, + {DatetimeExtractorType_SEVENTY, 70}, + {DatetimeExtractorType_EIGHTY, 80}, + {DatetimeExtractorType_NINETY, 90}, + {DatetimeExtractorType_HUNDRED, 100}, + {DatetimeExtractorType_THOUSAND, 1000}, + }) { + int rule_id; + if (!RuleIdForType(type_value_pair.first, &rule_id)) { + return false; + } + + std::unique_ptr<UniLib::RegexMatcher> matcher = + rules_[rule_id]->Matcher(input); + if (!matcher) { + return false; + } + + int status; + while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { + int span_start = matcher->Start(&status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + found_numbers.push_back({span_start, type_value_pair.second}); + } + } + + std::sort(found_numbers.begin(), found_numbers.end(), + [](const std::pair<int, int>& a, const std::pair<int, int>& b) { + return a.first < b.first; + }); + + int sum = 0; + int running_value = -1; + // Simple math to make sure we handle written numerical modifiers correctly + // so that :="fifty one thousand and one" maps to 51001 and not 50 1 1000 1. 
+ for (const std::pair<int, int> position_number_pair : found_numbers) { + if (running_value >= 0) { + if (running_value > position_number_pair.second) { + sum += running_value; + running_value = position_number_pair.second; + } else { + running_value *= position_number_pair.second; + } + } else { + running_value = position_number_pair.second; + } + } + sum += running_value; + *parsed_number = sum; + return true; +} + +bool DatetimeExtractor::ParseDigits(const UnicodeText& input, + int* parsed_digits) const { + UnicodeText digit; + if (!ExtractType(input, DatetimeExtractorType_DIGITS, &digit)) { + return false; + } + + if (!unilib_.ParseInt32(digit, parsed_digits)) { + return false; + } + return true; +} + +bool DatetimeExtractor::ParseYear(const UnicodeText& input, + int* parsed_year) const { + if (!ParseDigits(input, parsed_year)) { + return false; + } + + if (*parsed_year < 100) { + if (*parsed_year < 50) { + *parsed_year += 2000; + } else { + *parsed_year += 1900; + } + } + + return true; +} + +bool DatetimeExtractor::ParseMonth(const UnicodeText& input, + int* parsed_month) const { + if (ParseDigits(input, parsed_month)) { + return true; + } + + if (MapInput(input, + { + {DatetimeExtractorType_JANUARY, 1}, + {DatetimeExtractorType_FEBRUARY, 2}, + {DatetimeExtractorType_MARCH, 3}, + {DatetimeExtractorType_APRIL, 4}, + {DatetimeExtractorType_MAY, 5}, + {DatetimeExtractorType_JUNE, 6}, + {DatetimeExtractorType_JULY, 7}, + {DatetimeExtractorType_AUGUST, 8}, + {DatetimeExtractorType_SEPTEMBER, 9}, + {DatetimeExtractorType_OCTOBER, 10}, + {DatetimeExtractorType_NOVEMBER, 11}, + {DatetimeExtractorType_DECEMBER, 12}, + }, + parsed_month)) { + return true; + } + + return false; +} + +bool DatetimeExtractor::ParseAMPM(const UnicodeText& input, + int* parsed_ampm) const { + return MapInput(input, + { + {DatetimeExtractorType_AM, DateParseData::AMPM::AM}, + {DatetimeExtractorType_PM, DateParseData::AMPM::PM}, + }, + parsed_ampm); +} + +bool 
DatetimeExtractor::ParseRelationDistance(const UnicodeText& input, + int* parsed_distance) const { + if (ParseDigits(input, parsed_distance)) { + return true; + } + if (ParseWrittenNumber(input, parsed_distance)) { + return true; + } + return false; +} + +bool DatetimeExtractor::ParseRelation( + const UnicodeText& input, DateParseData::Relation* parsed_relation) const { + return MapInput( + input, + { + {DatetimeExtractorType_NOW, DateParseData::Relation::NOW}, + {DatetimeExtractorType_YESTERDAY, DateParseData::Relation::YESTERDAY}, + {DatetimeExtractorType_TOMORROW, DateParseData::Relation::TOMORROW}, + {DatetimeExtractorType_NEXT, DateParseData::Relation::NEXT}, + {DatetimeExtractorType_NEXT_OR_SAME, + DateParseData::Relation::NEXT_OR_SAME}, + {DatetimeExtractorType_LAST, DateParseData::Relation::LAST}, + {DatetimeExtractorType_PAST, DateParseData::Relation::PAST}, + {DatetimeExtractorType_FUTURE, DateParseData::Relation::FUTURE}, + }, + parsed_relation); +} + +bool DatetimeExtractor::ParseRelationType( + const UnicodeText& input, + DateParseData::RelationType* parsed_relation_type) const { + return MapInput( + input, + { + {DatetimeExtractorType_MONDAY, DateParseData::MONDAY}, + {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY}, + {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY}, + {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY}, + {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY}, + {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY}, + {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY}, + {DatetimeExtractorType_DAY, DateParseData::DAY}, + {DatetimeExtractorType_WEEK, DateParseData::WEEK}, + {DatetimeExtractorType_MONTH, DateParseData::MONTH}, + {DatetimeExtractorType_YEAR, DateParseData::YEAR}, + }, + parsed_relation_type); +} + +bool DatetimeExtractor::ParseTimeUnit(const UnicodeText& input, + int* parsed_time_unit) const { + return MapInput(input, + { + {DatetimeExtractorType_DAYS, DateParseData::DAYS}, + 
{DatetimeExtractorType_WEEKS, DateParseData::WEEKS}, + {DatetimeExtractorType_MONTHS, DateParseData::MONTHS}, + {DatetimeExtractorType_HOURS, DateParseData::HOURS}, + {DatetimeExtractorType_MINUTES, DateParseData::MINUTES}, + {DatetimeExtractorType_SECONDS, DateParseData::SECONDS}, + {DatetimeExtractorType_YEARS, DateParseData::YEARS}, + }, + parsed_time_unit); +} + +bool DatetimeExtractor::ParseWeekday(const UnicodeText& input, + int* parsed_weekday) const { + return MapInput( + input, + { + {DatetimeExtractorType_MONDAY, DateParseData::MONDAY}, + {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY}, + {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY}, + {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY}, + {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY}, + {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY}, + {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY}, + }, + parsed_weekday); +} + +} // namespace libtextclassifier2 diff --git a/datetime/extractor.h b/datetime/extractor.h new file mode 100644 index 0000000..5c36ec4 --- /dev/null +++ b/datetime/extractor.h @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_ +#define LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +#include "model_generated.h" +#include "types.h" +#include "util/strings/stringpiece.h" +#include "util/utf8/unicodetext.h" +#include "util/utf8/unilib.h" + +namespace libtextclassifier2 { + +struct CompiledRule { + // The compiled regular expression. + std::unique_ptr<const UniLib::RegexPattern> compiled_regex; + + // The uncompiled pattern and information about the pattern groups. + const DatetimeModelPattern_::Regex* regex; + + // DatetimeModelPattern which 'regex' is part of and comes from. + const DatetimeModelPattern* pattern; +}; + +// A helper class for DatetimeParser that extracts structured data +// (DateParseDate) from the current match of the passed RegexMatcher. +class DatetimeExtractor { + public: + DatetimeExtractor( + const CompiledRule& rule, const UniLib::RegexMatcher& matcher, + int locale_id, const UniLib& unilib, + const std::vector<std::unique_ptr<const UniLib::RegexPattern>>& + extractor_rules, + const std::unordered_map<DatetimeExtractorType, + std::unordered_map<int, int>>& + type_and_locale_to_extractor_rule) + : rule_(rule), + matcher_(matcher), + locale_id_(locale_id), + unilib_(unilib), + rules_(extractor_rules), + type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {} + bool Extract(DateParseData* result, CodepointSpan* result_span) const; + + private: + bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const; + + // Returns true if the rule for given extractor matched. If it matched, + // match_result will contain the first group of the rule (if match_result not + // nullptr). 
+ bool ExtractType(const UnicodeText& input, + DatetimeExtractorType extractor_type, + UnicodeText* match_result = nullptr) const; + + bool GroupTextFromMatch(int group_id, UnicodeText* result) const; + + // Updates the span to include the current match for the given group. + bool UpdateMatchSpan(int group_id, CodepointSpan* span) const; + + // Returns true if any of the extractors from 'mapping' matched. If it did, + // will fill 'result' with the associated value from 'mapping'. + template <typename T> + bool MapInput(const UnicodeText& input, + const std::vector<std::pair<DatetimeExtractorType, T>>& mapping, + T* result) const; + + bool ParseDigits(const UnicodeText& input, int* parsed_digits) const; + bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const; + bool ParseYear(const UnicodeText& input, int* parsed_year) const; + bool ParseMonth(const UnicodeText& input, int* parsed_month) const; + bool ParseAMPM(const UnicodeText& input, int* parsed_ampm) const; + bool ParseRelation(const UnicodeText& input, + DateParseData::Relation* parsed_relation) const; + bool ParseRelationDistance(const UnicodeText& input, + int* parsed_distance) const; + bool ParseTimeUnit(const UnicodeText& input, int* parsed_time_unit) const; + bool ParseRelationType( + const UnicodeText& input, + DateParseData::RelationType* parsed_relation_type) const; + bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const; + + const CompiledRule& rule_; + const UniLib::RegexMatcher& matcher_; + int locale_id_; + const UniLib& unilib_; + const std::vector<std::unique_ptr<const UniLib::RegexPattern>>& rules_; + const std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>& + type_and_locale_to_rule_; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_ diff --git a/datetime/parser.cc b/datetime/parser.cc new file mode 100644 index 0000000..4bc5dff --- /dev/null +++ b/datetime/parser.cc @@ -0,0 +1,405 @@ +/* + * 
Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "datetime/parser.h" + +#include <set> +#include <unordered_set> + +#include "datetime/extractor.h" +#include "util/calendar/calendar.h" +#include "util/i18n/locale.h" +#include "util/strings/split.h" + +namespace libtextclassifier2 { +std::unique_ptr<DatetimeParser> DatetimeParser::Instance( + const DatetimeModel* model, const UniLib& unilib, + ZlibDecompressor* decompressor) { + std::unique_ptr<DatetimeParser> result( + new DatetimeParser(model, unilib, decompressor)); + if (!result->initialized_) { + result.reset(); + } + return result; +} + +DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib, + ZlibDecompressor* decompressor) + : unilib_(unilib) { + initialized_ = false; + + if (model == nullptr) { + return; + } + + if (model->patterns() != nullptr) { + for (const DatetimeModelPattern* pattern : *model->patterns()) { + if (pattern->regexes()) { + for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) { + std::unique_ptr<UniLib::RegexPattern> regex_pattern = + UncompressMakeRegexPattern(unilib, regex->pattern(), + regex->compressed_pattern(), + decompressor); + if (!regex_pattern) { + TC_LOG(ERROR) << "Couldn't create rule pattern."; + return; + } + rules_.push_back({std::move(regex_pattern), regex, pattern}); + if (pattern->locales()) { + for (int locale : *pattern->locales()) { + 
locale_to_rules_[locale].push_back(rules_.size() - 1); + } + } + } + } + } + } + + if (model->extractors() != nullptr) { + for (const DatetimeModelExtractor* extractor : *model->extractors()) { + std::unique_ptr<UniLib::RegexPattern> regex_pattern = + UncompressMakeRegexPattern(unilib, extractor->pattern(), + extractor->compressed_pattern(), + decompressor); + if (!regex_pattern) { + TC_LOG(ERROR) << "Couldn't create extractor pattern"; + return; + } + extractor_rules_.push_back(std::move(regex_pattern)); + + if (extractor->locales()) { + for (int locale : *extractor->locales()) { + type_and_locale_to_extractor_rule_[extractor->extractor()][locale] = + extractor_rules_.size() - 1; + } + } + } + } + + if (model->locales() != nullptr) { + for (int i = 0; i < model->locales()->Length(); ++i) { + locale_string_to_id_[model->locales()->Get(i)->str()] = i; + } + } + + if (model->default_locales() != nullptr) { + for (const int locale : *model->default_locales()) { + default_locale_ids_.push_back(locale); + } + } + + use_extractors_for_locating_ = model->use_extractors_for_locating(); + + initialized_ = true; +} + +bool DatetimeParser::Parse( + const std::string& input, const int64 reference_time_ms_utc, + const std::string& reference_timezone, const std::string& locales, + ModeFlag mode, bool anchor_start_end, + std::vector<DatetimeParseResultSpan>* results) const { + return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), + reference_time_ms_utc, reference_timezone, locales, mode, + anchor_start_end, results); +} + +bool DatetimeParser::FindSpansUsingLocales( + const std::vector<int>& locale_ids, const UnicodeText& input, + const int64 reference_time_ms_utc, const std::string& reference_timezone, + ModeFlag mode, bool anchor_start_end, const std::string& reference_locale, + std::unordered_set<int>* executed_rules, + std::vector<DatetimeParseResultSpan>* found_spans) const { + for (const int locale_id : locale_ids) { + auto rules_it = locale_to_rules_.find(locale_id); 
+ if (rules_it == locale_to_rules_.end()) { + continue; + } + + for (const int rule_id : rules_it->second) { + // Skip rules that were already executed in previous locales. + if (executed_rules->find(rule_id) != executed_rules->end()) { + continue; + } + + if (!(rules_[rule_id].pattern->enabled_modes() & mode)) { + continue; + } + + executed_rules->insert(rule_id); + + if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc, + reference_timezone, reference_locale, locale_id, + anchor_start_end, found_spans)) { + return false; + } + } + } + return true; +} + +bool DatetimeParser::Parse( + const UnicodeText& input, const int64 reference_time_ms_utc, + const std::string& reference_timezone, const std::string& locales, + ModeFlag mode, bool anchor_start_end, + std::vector<DatetimeParseResultSpan>* results) const { + std::vector<DatetimeParseResultSpan> found_spans; + std::unordered_set<int> executed_rules; + std::string reference_locale; + const std::vector<int> requested_locales = + ParseAndExpandLocales(locales, &reference_locale); + if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc, + reference_timezone, mode, anchor_start_end, + reference_locale, &executed_rules, &found_spans)) { + return false; + } + + std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans; + int counter = 0; + for (const auto& found_span : found_spans) { + indexed_found_spans.push_back({found_span, counter}); + counter++; + } + + // Resolve conflicts by always picking the longer span and breaking ties by + // selecting the earlier entry in the list for a given locale. 
+ std::sort(indexed_found_spans.begin(), indexed_found_spans.end(), + [](const std::pair<DatetimeParseResultSpan, int>& a, + const std::pair<DatetimeParseResultSpan, int>& b) { + if ((a.first.span.second - a.first.span.first) != + (b.first.span.second - b.first.span.first)) { + return (a.first.span.second - a.first.span.first) > + (b.first.span.second - b.first.span.first); + } else { + return a.second < b.second; + } + }); + + found_spans.clear(); + for (auto& span_index_pair : indexed_found_spans) { + found_spans.push_back(span_index_pair.first); + } + + std::set<int, std::function<bool(int, int)>> chosen_indices_set( + [&found_spans](int a, int b) { + return found_spans[a].span.first < found_spans[b].span.first; + }); + for (int i = 0; i < found_spans.size(); ++i) { + if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) { + chosen_indices_set.insert(i); + results->push_back(found_spans[i]); + } + } + + return true; +} + +bool DatetimeParser::HandleParseMatch( + const CompiledRule& rule, const UniLib::RegexMatcher& matcher, + int64 reference_time_ms_utc, const std::string& reference_timezone, + const std::string& reference_locale, int locale_id, + std::vector<DatetimeParseResultSpan>* result) const { + int status = UniLib::RegexMatcher::kNoError; + const int start = matcher.Start(&status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + + const int end = matcher.End(&status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + + DatetimeParseResultSpan parse_result; + if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone, + reference_locale, locale_id, &(parse_result.data), + &parse_result.span)) { + return false; + } + if (!use_extractors_for_locating_) { + parse_result.span = {start, end}; + } + if (parse_result.span.first != kInvalidIndex && + parse_result.span.second != kInvalidIndex) { + parse_result.target_classification_score = + rule.pattern->target_classification_score(); + 
parse_result.priority_score = rule.pattern->priority_score(); + result->push_back(parse_result); + } + return true; +} + +bool DatetimeParser::ParseWithRule( + const CompiledRule& rule, const UnicodeText& input, + const int64 reference_time_ms_utc, const std::string& reference_timezone, + const std::string& reference_locale, const int locale_id, + bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const { + std::unique_ptr<UniLib::RegexMatcher> matcher = + rule.compiled_regex->Matcher(input); + int status = UniLib::RegexMatcher::kNoError; + if (anchor_start_end) { + if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) { + if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc, + reference_timezone, reference_locale, locale_id, + result)) { + return false; + } + } + } else { + while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { + if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc, + reference_timezone, reference_locale, locale_id, + result)) { + return false; + } + } + } + return true; +} + +std::vector<int> DatetimeParser::ParseAndExpandLocales( + const std::string& locales, std::string* reference_locale) const { + std::vector<StringPiece> split_locales = strings::Split(locales, ','); + if (!split_locales.empty()) { + *reference_locale = split_locales[0].ToString(); + } else { + *reference_locale = ""; + } + + std::vector<int> result; + for (const StringPiece& locale_str : split_locales) { + auto locale_it = locale_string_to_id_.find(locale_str.ToString()); + if (locale_it != locale_string_to_id_.end()) { + result.push_back(locale_it->second); + } + + const Locale locale = Locale::FromBCP47(locale_str.ToString()); + if (!locale.IsValid()) { + continue; + } + + const std::string language = locale.Language(); + const std::string script = locale.Script(); + const std::string region = locale.Region(); + + // First, try adding *-region locale. 
+ if (!region.empty()) { + locale_it = locale_string_to_id_.find("*-" + region); + if (locale_it != locale_string_to_id_.end()) { + result.push_back(locale_it->second); + } + } + // Second, try adding language-script-* locale. + if (!script.empty()) { + locale_it = locale_string_to_id_.find(language + "-" + script + "-*"); + if (locale_it != locale_string_to_id_.end()) { + result.push_back(locale_it->second); + } + } + // Third, try adding language-* locale. + if (!language.empty()) { + locale_it = locale_string_to_id_.find(language + "-*"); + if (locale_it != locale_string_to_id_.end()) { + result.push_back(locale_it->second); + } + } + } + + // Add the default locales if they haven't been added already. + const std::unordered_set<int> result_set(result.begin(), result.end()); + for (const int default_locale_id : default_locale_ids_) { + if (result_set.find(default_locale_id) == result_set.end()) { + result.push_back(default_locale_id); + } + } + + return result; +} + +namespace { + +DatetimeGranularity GetGranularity(const DateParseData& data) { + DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR; + if ((data.field_set_mask & DateParseData::YEAR_FIELD) || + (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && + (data.relation_type == DateParseData::RelationType::YEAR))) { + granularity = DatetimeGranularity::GRANULARITY_YEAR; + } + if ((data.field_set_mask & DateParseData::MONTH_FIELD) || + (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && + (data.relation_type == DateParseData::RelationType::MONTH))) { + granularity = DatetimeGranularity::GRANULARITY_MONTH; + } + if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && + (data.relation_type == DateParseData::RelationType::WEEK)) { + granularity = DatetimeGranularity::GRANULARITY_WEEK; + } + if (data.field_set_mask & DateParseData::DAY_FIELD || + (data.field_set_mask & DateParseData::RELATION_FIELD && + (data.relation == DateParseData::Relation::NOW || + 
data.relation == DateParseData::Relation::TOMORROW || + data.relation == DateParseData::Relation::YESTERDAY)) || + (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && + (data.relation_type == DateParseData::RelationType::MONDAY || + data.relation_type == DateParseData::RelationType::TUESDAY || + data.relation_type == DateParseData::RelationType::WEDNESDAY || + data.relation_type == DateParseData::RelationType::THURSDAY || + data.relation_type == DateParseData::RelationType::FRIDAY || + data.relation_type == DateParseData::RelationType::SATURDAY || + data.relation_type == DateParseData::RelationType::SUNDAY || + data.relation_type == DateParseData::RelationType::DAY))) { + granularity = DatetimeGranularity::GRANULARITY_DAY; + } + if (data.field_set_mask & DateParseData::HOUR_FIELD) { + granularity = DatetimeGranularity::GRANULARITY_HOUR; + } + if (data.field_set_mask & DateParseData::MINUTE_FIELD) { + granularity = DatetimeGranularity::GRANULARITY_MINUTE; + } + if (data.field_set_mask & DateParseData::SECOND_FIELD) { + granularity = DatetimeGranularity::GRANULARITY_SECOND; + } + return granularity; +} + +} // namespace + +bool DatetimeParser::ExtractDatetime(const CompiledRule& rule, + const UniLib::RegexMatcher& matcher, + const int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, + int locale_id, DatetimeParseResult* result, + CodepointSpan* result_span) const { + DateParseData parse; + DatetimeExtractor extractor(rule, matcher, locale_id, unilib_, + extractor_rules_, + type_and_locale_to_extractor_rule_); + if (!extractor.Extract(&parse, result_span)) { + return false; + } + + result->granularity = GetGranularity(parse); + + if (!calendar_lib_.InterpretParseData( + parse, reference_time_ms_utc, reference_timezone, reference_locale, + result->granularity, &(result->time_ms_utc))) { + return false; + } + + return true; +} + +} // namespace libtextclassifier2 diff --git a/datetime/parser.h 
b/datetime/parser.h new file mode 100644 index 0000000..0666607 --- /dev/null +++ b/datetime/parser.h @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_DATETIME_PARSER_H_ +#define LIBTEXTCLASSIFIER_DATETIME_PARSER_H_ + +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "datetime/extractor.h" +#include "model_generated.h" +#include "types.h" +#include "util/base/integral_types.h" +#include "util/calendar/calendar.h" +#include "util/utf8/unilib.h" +#include "zlib-utils.h" + +namespace libtextclassifier2 { + +// Parses datetime expressions in the input and resolves them to actual absolute +// time. +class DatetimeParser { + public: + static std::unique_ptr<DatetimeParser> Instance( + const DatetimeModel* model, const UniLib& unilib, + ZlibDecompressor* decompressor); + + // Parses the dates in 'input' and fills result. Makes sure that the results + // do not overlap. + // If 'anchor_start_end' is true the extracted results need to start at the + // beginning of 'input' and end at the end of it. + bool Parse(const std::string& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const std::string& locales, + ModeFlag mode, bool anchor_start_end, + std::vector<DatetimeParseResultSpan>* results) const; + + // Same as above but takes UnicodeText. 
+ bool Parse(const UnicodeText& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const std::string& locales, + ModeFlag mode, bool anchor_start_end, + std::vector<DatetimeParseResultSpan>* results) const; + + protected: + DatetimeParser(const DatetimeModel* model, const UniLib& unilib, + ZlibDecompressor* decompressor); + + // Returns a list of locale ids for given locale spec string (comma-separated + // locale names). Assigns the first parsed locale to reference_locale. + std::vector<int> ParseAndExpandLocales(const std::string& locales, + std::string* reference_locale) const; + + // Helper function that finds datetime spans, only using the rules associated + // with the given locales. + bool FindSpansUsingLocales( + const std::vector<int>& locale_ids, const UnicodeText& input, + const int64 reference_time_ms_utc, const std::string& reference_timezone, + ModeFlag mode, bool anchor_start_end, const std::string& reference_locale, + std::unordered_set<int>* executed_rules, + std::vector<DatetimeParseResultSpan>* found_spans) const; + + bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, const int locale_id, + bool anchor_start_end, + std::vector<DatetimeParseResultSpan>* result) const; + + // Converts the current match in 'matcher' into DatetimeParseResult. + bool ExtractDatetime(const CompiledRule& rule, + const UniLib::RegexMatcher& matcher, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, int locale_id, + DatetimeParseResult* result, + CodepointSpan* result_span) const; + + // Parse and extract information from current match in 'matcher'. 
+ bool HandleParseMatch(const CompiledRule& rule, + const UniLib::RegexMatcher& matcher, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, int locale_id, + std::vector<DatetimeParseResultSpan>* result) const; + + private: + bool initialized_; + const UniLib& unilib_; + std::vector<CompiledRule> rules_; + std::unordered_map<int, std::vector<int>> locale_to_rules_; + std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_; + std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>> + type_and_locale_to_extractor_rule_; + std::unordered_map<std::string, int> locale_string_to_id_; + std::vector<int> default_locale_ids_; + CalendarLib calendar_lib_; + bool use_extractors_for_locating_; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_DATETIME_PARSER_H_ diff --git a/datetime/parser_test.cc b/datetime/parser_test.cc new file mode 100644 index 0000000..e61ed12 --- /dev/null +++ b/datetime/parser_test.cc @@ -0,0 +1,457 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <time.h> +#include <fstream> +#include <iostream> +#include <memory> +#include <string> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "datetime/parser.h" +#include "model_generated.h" +#include "text-classifier.h" +#include "types-test-util.h" + +using testing::ElementsAreArray; + +namespace libtextclassifier2 { +namespace { + +std::string GetModelPath() { + return LIBTEXTCLASSIFIER_TEST_DATA_DIR; +} + +std::string ReadFile(const std::string& file_name) { + std::ifstream file_stream(file_name); + return std::string(std::istreambuf_iterator<char>(file_stream), {}); +} + +std::string FormatMillis(int64 time_ms_utc) { + long time_seconds = time_ms_utc / 1000; // NOLINT + // Format time, "ddd yyyy-mm-dd hh:mm:ss zzz" + char buffer[512]; + strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z", + localtime(&time_seconds)); + return std::string(buffer); +} + +class ParserTest : public testing::Test { + public: + void SetUp() override { + model_buffer_ = ReadFile(GetModelPath() + "test_model.fb"); + classifier_ = TextClassifier::FromUnownedBuffer( + model_buffer_.data(), model_buffer_.size(), &unilib_); + TC_CHECK(classifier_); + parser_ = classifier_->DatetimeParserForTests(); + } + + bool HasNoResult(const std::string& text, bool anchor_start_end = false, + const std::string& timezone = "Europe/Zurich") { + std::vector<DatetimeParseResultSpan> results; + if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION, + anchor_start_end, &results)) { + TC_LOG(ERROR) << text; + TC_CHECK(false); + } + return results.empty(); + } + + bool ParsesCorrectly(const std::string& marked_text, + const int64 expected_ms_utc, + DatetimeGranularity expected_granularity, + bool anchor_start_end = false, + const std::string& timezone = "Europe/Zurich", + const std::string& locales = "en-US") { + const UnicodeText marked_text_unicode = + UTF8ToUnicodeText(marked_text, /*do_copy=*/false); + auto brace_open_it = + 
std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{'); + auto brace_end_it = + std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}'); + TC_CHECK(brace_open_it != marked_text_unicode.end()); + TC_CHECK(brace_end_it != marked_text_unicode.end()); + + std::string text; + text += + UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it); + text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it); + text += UnicodeText::UTF8Substring(std::next(brace_end_it), + marked_text_unicode.end()); + + std::vector<DatetimeParseResultSpan> results; + + if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION, + anchor_start_end, &results)) { + TC_LOG(ERROR) << text; + TC_CHECK(false); + } + if (results.empty()) { + TC_LOG(ERROR) << "No results."; + return false; + } + + const int expected_start_index = + std::distance(marked_text_unicode.begin(), brace_open_it); + // The -1 bellow is to account for the opening bracket character. 
+ const int expected_end_index = + std::distance(marked_text_unicode.begin(), brace_end_it) - 1; + + std::vector<DatetimeParseResultSpan> filtered_results; + for (const DatetimeParseResultSpan& result : results) { + if (SpansOverlap(result.span, + {expected_start_index, expected_end_index})) { + filtered_results.push_back(result); + } + } + + const std::vector<DatetimeParseResultSpan> expected{ + {{expected_start_index, expected_end_index}, + {expected_ms_utc, expected_granularity}, + /*target_classification_score=*/1.0, + /*priority_score=*/0.0}}; + const bool matches = + testing::Matches(ElementsAreArray(expected))(filtered_results); + if (!matches) { + TC_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: " + << FormatMillis(expected[0].data.time_ms_utc); + for (int i = 0; i < filtered_results.size(); ++i) { + TC_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i] + << " which corresponds to: " + << FormatMillis(filtered_results[i].data.time_ms_utc); + } + } + return matches; + } + + bool ParsesCorrectlyGerman(const std::string& marked_text, + const int64 expected_ms_utc, + DatetimeGranularity expected_granularity) { + return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity, + /*anchor_start_end=*/false, + /*timezone=*/"Europe/Zurich", /*locales=*/"de"); + } + + protected: + std::string model_buffer_; + std::unique_ptr<TextClassifier> classifier_; + const DatetimeParser* parser_; + UniLib unilib_; +}; + +// Test with just a few cases to make debugging of general failures easier. 
+TEST_F(ParserTest, ParseShort) { + EXPECT_TRUE( + ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY)); +} + +TEST_F(ParserTest, Parse) { + EXPECT_TRUE( + ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000, + GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectly("{09/Mar/2004 22:02:40}", 1078866160000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{Dec 2, 2010 2:39:58 AM}", 1291253998000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000, + GRANULARITY_SECOND)); + EXPECT_TRUE( + ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29},573", 1277512289000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{11:42:35}.173", 38555000, GRANULARITY_SECOND)); + EXPECT_TRUE( + ParsesCorrectly("{23/Apr 11:42:35},173", 9715355000, GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}.883", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}.883", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + 
EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}.883", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly( + "Are sentiments apartments decisively the especially alteration. " + "Thrown shy denote ten ladies though ask saw. Or by to he going " + "think order event music. Incommode so intention defective at " + "convinced. Led income months itself and houses you. After nor " + "you leave might share court balls. {19/apr/2010 06:36:15} Are " + "sentiments apartments decisively the especially alteration. " + "Thrown shy denote ten ladies though ask saw. Or by to he going " + "think order event music. Incommode so intention defective at " + "convinced. Led income months itself and houses you. After nor " + "you leave might share court balls. ", + 1271651775000, GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000, + GRANULARITY_MINUTE)); + EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000, + GRANULARITY_MINUTE)); + EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000, + GRANULARITY_HOUR)); + + EXPECT_TRUE(ParsesCorrectly("{today}", -3600000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectly("{today}", -57600000, GRANULARITY_DAY, + /*anchor_start_end=*/false, + "America/Los_Angeles")); + EXPECT_TRUE(ParsesCorrectly("{next week}", 255600000, GRANULARITY_WEEK)); + EXPECT_TRUE(ParsesCorrectly("{next day}", 82800000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectly("{in three days}", 255600000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectly("{in three weeks}", 1465200000, GRANULARITY_WEEK)); + EXPECT_TRUE(ParsesCorrectly("{tomorrow}", 82800000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE)); + EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4}", 97200000, GRANULARITY_HOUR)); + EXPECT_TRUE(ParsesCorrectly("{next wednesday}", 
514800000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectly("{next wednesday at 4}", 529200000, GRANULARITY_HOUR)); + EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000, + GRANULARITY_MINUTE)); + EXPECT_TRUE(ParsesCorrectly("{Three days ago}", -262800000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY)); +} + +TEST_F(ParserTest, ParseWithAnchor) { + EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000, + GRANULARITY_DAY, /*anchor_start_end=*/false)); + EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000, + GRANULARITY_DAY, /*anchor_start_end=*/true)); + EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000, + GRANULARITY_DAY, /*anchor_start_end=*/false)); + EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum", + /*anchor_start_end=*/true)); +} + +TEST_F(ParserTest, ParseGerman) { + EXPECT_TRUE( + ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum", + 1514761200000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}", 1271651775000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}", 1291253998000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{März 16 08:12:04}", 6419524000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29},573", 1277512289000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}", 1137899465000, + GRANULARITY_SECOND)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{11:42:35}", 38555000, GRANULARITY_SECOND)); + EXPECT_TRUE( + 
ParsesCorrectlyGerman("{11:42:35}.173", 38555000, GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35},173", 9715355000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}.883", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}.883", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}.883", 1429782155000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}", 1271651775000, + GRANULARITY_SECOND)); + EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}", 1514777400000, + GRANULARITY_MINUTE)); + EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}", + 1514820600000, GRANULARITY_MINUTE)); + EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4 nachm}", 1514818800000, + GRANULARITY_HOUR)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectlyGerman("{heute}", -3600000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{nächste Woche}", 342000000, GRANULARITY_WEEK)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{nächsten Tag}", 82800000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{in drei Tagen}", 255600000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{in drei Wochen}", 1551600000, 
GRANULARITY_WEEK)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{vor drei Tagen}", -262800000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectlyGerman("{morgen}", 82800000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{morgen um 4:00}", 97200000, GRANULARITY_MINUTE)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{morgen um 4}", 97200000, GRANULARITY_HOUR)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{nächsten Mittwoch}", 514800000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectlyGerman("{nächsten Mittwoch um 4}", 529200000, + GRANULARITY_HOUR)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{Vor drei Tagen}", -262800000, GRANULARITY_DAY)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{in einer woche}", 342000000, GRANULARITY_WEEK)); + EXPECT_TRUE( + ParsesCorrectlyGerman("{in einer tag}", 82800000, GRANULARITY_DAY)); +} + +TEST_F(ParserTest, ParseNonUs) { + EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY, + /*anchor_start_end=*/false, + /*timezone=*/"Europe/Zurich", + /*locales=*/"en-GB")); + EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY, + /*anchor_start_end=*/false, + /*timezone=*/"Europe/Zurich", /*locales=*/"en")); +} + +TEST_F(ParserTest, ParseUs) { + EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY, + /*anchor_start_end=*/false, + /*timezone=*/"Europe/Zurich", + /*locales=*/"en-US")); + EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY, + /*anchor_start_end=*/false, + /*timezone=*/"Europe/Zurich", + /*locales=*/"es-US")); +} + +TEST_F(ParserTest, ParseUnknownLanguage) { + EXPECT_TRUE(ParsesCorrectly("bylo to {31. 12. 
2015} v 6 hodin", 1451516400000, + GRANULARITY_DAY, + /*anchor_start_end=*/false, + /*timezone=*/"Europe/Zurich", /*locales=*/"xx")); +} + +class ParserLocaleTest : public testing::Test { + public: + void SetUp() override; + bool HasResult(const std::string& input, const std::string& locales); + + protected: + UniLib unilib_; + flatbuffers::FlatBufferBuilder builder_; + std::unique_ptr<DatetimeParser> parser_; +}; + +void AddPattern(const std::string& regex, int locale, + std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) { + patterns->emplace_back(new DatetimeModelPatternT); + patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT); + patterns->back()->regexes.back()->pattern = regex; + patterns->back()->regexes.back()->groups.push_back( + DatetimeGroupType_GROUP_UNUSED); + patterns->back()->locales.push_back(locale); +} + +void ParserLocaleTest::SetUp() { + DatetimeModelT model; + model.use_extractors_for_locating = false; + model.locales.clear(); + model.locales.push_back("en-US"); + model.locales.push_back("en-CH"); + model.locales.push_back("zh-Hant"); + model.locales.push_back("en-*"); + model.locales.push_back("zh-Hant-*"); + model.locales.push_back("*-CH"); + model.locales.push_back("default"); + model.default_locales.push_back(6); + + AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns); + AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns); + AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns); + AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns); + AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns); + AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns); + AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns); + + builder_.Finish(DatetimeModel::Pack(builder_, &model)); + const DatetimeModel* model_fb = + flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer()); + ASSERT_TRUE(model_fb); + + parser_ = DatetimeParser::Instance(model_fb, unilib_, + 
/*decompressor=*/nullptr); + ASSERT_TRUE(parser_); +} + +bool ParserLocaleTest::HasResult(const std::string& input, + const std::string& locales) { + std::vector<DatetimeParseResultSpan> results; + EXPECT_TRUE(parser_->Parse(input, /*reference_time_ms_utc=*/0, + /*reference_timezone=*/"", locales, + ModeFlag_ANNOTATION, false, &results)); + return results.size() == 1; +} + +TEST_F(ParserLocaleTest, English) { + EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US")); + EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US")); + EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH")); + EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH")); + EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH")); +} + +TEST_F(ParserLocaleTest, TraditionalChinese) { + EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant")); + EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW")); + EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG")); + EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG")); + EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh")); + EXPECT_TRUE(HasResult("default", /*locales=*/"zh")); + EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG")); +} + +TEST_F(ParserLocaleTest, SwissEnglish) { + EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH")); + EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH")); + EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH")); + EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE")); + EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH")); + EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH")); +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/smartselect/feature-processor.cc b/feature-processor.cc index c1db95a..551e649 100644 --- a/smartselect/feature-processor.cc +++ b/feature-processor.cc @@ -14,59 +14,51 @@ * limitations under the License. 
*/ -#include "smartselect/feature-processor.h" +#include "feature-processor.h" #include <iterator> #include <set> #include <vector> -#include "smartselect/text-classification-model.pb.h" #include "util/base/logging.h" #include "util/strings/utf8.h" #include "util/utf8/unicodetext.h" -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT -#include "unicode/brkiter.h" -#include "unicode/errorcode.h" -#include "unicode/uchar.h" -#endif -namespace libtextclassifier { +namespace libtextclassifier2 { namespace internal { TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( - const FeatureProcessorOptions& options) { + const FeatureProcessorOptions* const options) { TokenFeatureExtractorOptions extractor_options; - extractor_options.num_buckets = options.num_buckets(); - for (int order : options.chargram_orders()) { - extractor_options.chargram_orders.push_back(order); + extractor_options.num_buckets = options->num_buckets(); + if (options->chargram_orders() != nullptr) { + for (int order : *options->chargram_orders()) { + extractor_options.chargram_orders.push_back(order); + } } - extractor_options.max_word_length = options.max_word_length(); - extractor_options.extract_case_feature = options.extract_case_feature(); - extractor_options.unicode_aware_features = options.unicode_aware_features(); + extractor_options.max_word_length = options->max_word_length(); + extractor_options.extract_case_feature = options->extract_case_feature(); + extractor_options.unicode_aware_features = options->unicode_aware_features(); extractor_options.extract_selection_mask_feature = - options.extract_selection_mask_feature(); - for (int i = 0; i < options.regexp_feature_size(); ++i) { - extractor_options.regexp_features.push_back(options.regexp_feature(i)); + options->extract_selection_mask_feature(); + if (options->regexp_feature() != nullptr) { + for (const auto& regexp_feauture : *options->regexp_feature()) { + extractor_options.regexp_features.push_back(regexp_feauture->str()); + } } - 
extractor_options.remap_digits = options.remap_digits(); - extractor_options.lowercase_tokens = options.lowercase_tokens(); + extractor_options.remap_digits = options->remap_digits(); + extractor_options.lowercase_tokens = options->lowercase_tokens(); - for (const auto& chargram : options.allowed_chargrams()) { - extractor_options.allowed_chargrams.insert(chargram); + if (options->allowed_chargrams() != nullptr) { + for (const auto& chargram : *options->allowed_chargrams()) { + extractor_options.allowed_chargrams.insert(chargram->str()); + } } - return extractor_options; } -FeatureProcessorOptions ParseSerializedOptions( - const std::string& serialized_options) { - FeatureProcessorOptions options; - options.ParseFromString(serialized_options); - return options; -} - void SplitTokensOnSelectionBoundaries(CodepointSpan selection, std::vector<Token>* tokens) { for (auto it = tokens->begin(); it != tokens->end(); ++it) { @@ -119,6 +111,16 @@ void SplitTokensOnSelectionBoundaries(CodepointSpan selection, } } +const UniLib* MaybeCreateUnilib(const UniLib* unilib, + std::unique_ptr<UniLib>* owned_unilib) { + if (unilib) { + return unilib; + } else { + owned_unilib->reset(new UniLib); + return owned_unilib->get(); + } +} + } // namespace internal void FeatureProcessor::StripTokensFromOtherLines( @@ -126,6 +128,12 @@ void FeatureProcessor::StripTokensFromOtherLines( std::vector<Token>* tokens) const { const UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false); + StripTokensFromOtherLines(context_unicode, span, tokens); +} + +void FeatureProcessor::StripTokensFromOtherLines( + const UnicodeText& context_unicode, CodepointSpan span, + std::vector<Token>* tokens) const { std::vector<UnicodeTextRange> lines = SplitContext(context_unicode); auto span_start = context_unicode.begin(); @@ -157,37 +165,43 @@ void FeatureProcessor::StripTokensFromOtherLines( } std::string FeatureProcessor::GetDefaultCollection() const { - if (options_.default_collection() < 0 
|| - options_.default_collection() >= options_.collections_size()) { + if (options_->default_collection() < 0 || + options_->collections() == nullptr || + options_->default_collection() >= options_->collections()->size()) { TC_LOG(ERROR) << "Invalid or missing default collection. Returning empty string."; return ""; } - return options_.collections(options_.default_collection()); + return (*options_->collections())[options_->default_collection()]->str(); +} + +std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const { + const UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false); + return Tokenize(text_unicode); } std::vector<Token> FeatureProcessor::Tokenize( - const std::string& utf8_text) const { - if (options_.tokenization_type() == - libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) { - return tokenizer_.Tokenize(utf8_text); - } else if (options_.tokenization_type() == - libtextclassifier::FeatureProcessorOptions::ICU || - options_.tokenization_type() == - libtextclassifier::FeatureProcessorOptions::MIXED) { + const UnicodeText& text_unicode) const { + if (options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER) { + return tokenizer_.Tokenize(text_unicode); + } else if (options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_ICU || + options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_MIXED) { std::vector<Token> result; - if (!ICUTokenize(utf8_text, &result)) { + if (!ICUTokenize(text_unicode, &result)) { return {}; } - if (options_.tokenization_type() == - libtextclassifier::FeatureProcessorOptions::MIXED) { - InternalRetokenize(utf8_text, &result); + if (options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_MIXED) { + InternalRetokenize(text_unicode, &result); } return result; } else { TC_LOG(ERROR) << "Unknown tokenization type specified. 
Using " "internal."; - return tokenizer_.Tokenize(utf8_text); + return tokenizer_.Tokenize(text_unicode); } } @@ -205,11 +219,11 @@ bool FeatureProcessor::LabelToSpan( const int result_begin_token_index = token_span.first; const Token& result_begin_token = - tokens[options_.context_size() - result_begin_token_index]; + tokens[options_->context_size() - result_begin_token_index]; const int result_begin_codepoint = result_begin_token.start; const int result_end_token_index = token_span.second; const Token& result_end_token = - tokens[options_.context_size() + result_end_token_index]; + tokens[options_->context_size() + result_end_token_index]; const int result_end_codepoint = result_end_token.end; if (result_begin_codepoint == kInvalidIndex || @@ -224,9 +238,11 @@ bool FeatureProcessor::LabelToSpan( UnicodeText::const_iterator token_end = token_end_unicode.end(); const int begin_ignored = CountIgnoredSpanBoundaryCodepoints( - token_begin, token_begin_unicode.end(), /*count_from_beginning=*/true); - const int end_ignored = CountIgnoredSpanBoundaryCodepoints( - token_end_unicode.begin(), token_end, /*count_from_beginning=*/false); + token_begin, token_begin_unicode.end(), + /*count_from_beginning=*/true); + const int end_ignored = + CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end, + /*count_from_beginning=*/false); // In case everything would be stripped, set the span to the original // beginning and zero length. if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) { @@ -257,8 +273,8 @@ bool FeatureProcessor::SpanToLabel( } const int click_position = - options_.context_size(); // Click is always in the middle. - const int padding = options_.context_size() - options_.max_selection_span(); + options_->context_size(); // Click is always in the middle. 
+ const int padding = options_->context_size() - options_->max_selection_span(); int span_left = 0; for (int i = click_position - 1; i >= padding; i--) { @@ -282,7 +298,7 @@ bool FeatureProcessor::SpanToLabel( bool tokens_match_span; const CodepointIndex tokens_start = tokens[click_position - span_left].start; const CodepointIndex tokens_end = tokens[click_position + span_right].end; - if (options_.snap_label_span_boundaries_to_containing_tokens()) { + if (options_->snap_label_span_boundaries_to_containing_tokens()) { tokens_match_span = tokens_start <= span.first && tokens_end >= span.second; } else { const UnicodeText token_left_unicode = UTF8ToUnicodeText( @@ -296,7 +312,8 @@ bool FeatureProcessor::SpanToLabel( const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints( span_begin, token_left_unicode.end(), /*count_from_beginning=*/true); const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints( - token_right_unicode.begin(), span_end, /*count_from_beginning=*/false); + token_right_unicode.begin(), span_end, + /*count_from_beginning=*/false); tokens_match_span = tokens_start <= span.first && tokens_start + num_punctuation_start >= span.first && @@ -422,19 +439,22 @@ int CenterTokenFromMiddleOfSelection( int FeatureProcessor::FindCenterToken(CodepointSpan span, const std::vector<Token>& tokens) const { - if (options_.center_token_selection_method() == - FeatureProcessorOptions::CENTER_TOKEN_FROM_CLICK) { + if (options_->center_token_selection_method() == + FeatureProcessorOptions_:: + CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) { return internal::CenterTokenFromClick(span, tokens); - } else if (options_.center_token_selection_method() == - FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION) { + } else if (options_->center_token_selection_method() == + FeatureProcessorOptions_:: + CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) { return internal::CenterTokenFromMiddleOfSelection(span, tokens); - } else if 
(options_.center_token_selection_method() == - FeatureProcessorOptions::DEFAULT_CENTER_TOKEN_METHOD) { + } else if (options_->center_token_selection_method() == + FeatureProcessorOptions_:: + CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) { // TODO(zilka): Remove once we have new models on the device. // It uses the fact that sharing model use // split_tokens_on_selection_boundaries and selection not. So depending on // this we select the right way of finding the click location. - if (!options_.split_tokens_on_selection_boundaries()) { + if (!options_->split_tokens_on_selection_boundaries()) { // SmartSelection model. return internal::CenterTokenFromClick(span, tokens); } else { @@ -462,15 +482,15 @@ bool FeatureProcessor::SelectionLabelSpans( } void FeatureProcessor::PrepareCodepointRanges( - const std::vector<FeatureProcessorOptions::CodepointRange>& + const std::vector<const FeatureProcessorOptions_::CodepointRange*>& codepoint_ranges, std::vector<CodepointRange>* prepared_codepoint_ranges) { prepared_codepoint_ranges->clear(); prepared_codepoint_ranges->reserve(codepoint_ranges.size()); - for (const FeatureProcessorOptions::CodepointRange& range : + for (const FeatureProcessorOptions_::CodepointRange* range : codepoint_ranges) { prepared_codepoint_ranges->push_back( - CodepointRange(range.start(), range.end())); + CodepointRange(range->start(), range->end())); } std::sort(prepared_codepoint_ranges->begin(), @@ -481,8 +501,10 @@ void FeatureProcessor::PrepareCodepointRanges( } void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() { - for (const int codepoint : options_.ignored_span_boundary_codepoints()) { - ignored_span_boundary_codepoints_.insert(codepoint); + if (options_->ignored_span_boundary_codepoints() != nullptr) { + for (const int codepoint : *options_->ignored_span_boundary_codepoints()) { + ignored_span_boundary_codepoints_.insert(codepoint); + } } } @@ -555,22 +577,25 @@ void FindSubstrings(const UnicodeText& t, const 
std::set<char32>& codepoints, std::vector<UnicodeTextRange> FeatureProcessor::SplitContext( const UnicodeText& context_unicode) const { - if (options_.only_use_line_with_click()) { - std::vector<UnicodeTextRange> lines; - std::set<char32> codepoints; - codepoints.insert('\n'); - codepoints.insert('|'); - FindSubstrings(context_unicode, codepoints, &lines); - return lines; - } else { - return {{context_unicode.begin(), context_unicode.end()}}; - } + std::vector<UnicodeTextRange> lines; + const std::set<char32> codepoints{{'\n', '|'}}; + FindSubstrings(context_unicode, codepoints, &lines); + return lines; } CodepointSpan FeatureProcessor::StripBoundaryCodepoints( const std::string& context, CodepointSpan span) const { const UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false); + return StripBoundaryCodepoints(context_unicode, span); +} + +CodepointSpan FeatureProcessor::StripBoundaryCodepoints( + const UnicodeText& context_unicode, CodepointSpan span) const { + if (context_unicode.empty() || !ValidNonEmptySpan(span)) { + return span; + } + UnicodeText::const_iterator span_begin = context_unicode.begin(); std::advance(span_begin, span.first); UnicodeText::const_iterator span_end = context_unicode.begin(); @@ -589,21 +614,17 @@ CodepointSpan FeatureProcessor::StripBoundaryCodepoints( } float FeatureProcessor::SupportedCodepointsRatio( - int click_pos, const std::vector<Token>& tokens) const { + const TokenSpan& token_span, const std::vector<Token>& tokens) const { int num_supported = 0; int num_total = 0; - for (int i = click_pos - options_.context_size(); - i <= click_pos + options_.context_size(); ++i) { - const bool is_valid_token = i >= 0 && i < tokens.size(); - if (is_valid_token) { - const UnicodeText value = - UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false); - for (auto codepoint : value) { - if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) { - ++num_supported; - } - ++num_total; + for (int i = token_span.first; i < 
token_span.second; ++i) { + const UnicodeText value = + UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false); + for (auto codepoint : value) { + if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) { + ++num_supported; } + ++num_total; } } return static_cast<float>(num_supported) / static_cast<float>(num_total); @@ -640,7 +661,7 @@ bool FeatureProcessor::IsCodepointInRanges( int FeatureProcessor::CollectionToLabel(const std::string& collection) const { const auto it = collection_to_label_.find(collection); if (it == collection_to_label_.end()) { - return options_.default_collection(); + return options_->default_collection(); } else { return it->second; } @@ -648,22 +669,24 @@ int FeatureProcessor::CollectionToLabel(const std::string& collection) const { std::string FeatureProcessor::LabelToCollection(int label) const { if (label >= 0 && label < collection_to_label_.size()) { - return options_.collections(label); + return (*options_->collections())[label]->str(); } else { return GetDefaultCollection(); } } void FeatureProcessor::MakeLabelMaps() { - for (int i = 0; i < options_.collections().size(); ++i) { - collection_to_label_[options_.collections(i)] = i; + if (options_->collections() != nullptr) { + for (int i = 0; i < options_->collections()->size(); ++i) { + collection_to_label_[(*options_->collections())[i]->str()] = i; + } } int selection_label_id = 0; - for (int l = 0; l < (options_.max_selection_span() + 1); ++l) { - for (int r = 0; r < (options_.max_selection_span() + 1); ++r) { - if (!options_.selection_reduced_output_space() || - r + l <= options_.max_selection_span()) { + for (int l = 0; l < (options_->max_selection_span() + 1); ++l) { + for (int r = 0; r < (options_->max_selection_span() + 1); ++r) { + if (!options_->selection_reduced_output_space() || + r + l <= options_->max_selection_span()) { TokenSpan token_span{l, r}; selection_to_label_[token_span] = selection_label_id; label_to_selection_.push_back(token_span); @@ -673,19 +696,29 
@@ void FeatureProcessor::MakeLabelMaps() { } } -void FeatureProcessor::TokenizeAndFindClick(const std::string& context, - CodepointSpan input_span, - std::vector<Token>* tokens, - int* click_pos) const { +void FeatureProcessor::RetokenizeAndFindClick(const std::string& context, + CodepointSpan input_span, + bool only_use_line_with_click, + std::vector<Token>* tokens, + int* click_pos) const { + const UnicodeText context_unicode = + UTF8ToUnicodeText(context, /*do_copy=*/false); + RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click, + tokens, click_pos); +} + +void FeatureProcessor::RetokenizeAndFindClick( + const UnicodeText& context_unicode, CodepointSpan input_span, + bool only_use_line_with_click, std::vector<Token>* tokens, + int* click_pos) const { TC_CHECK(tokens != nullptr); - *tokens = Tokenize(context); - if (options_.split_tokens_on_selection_boundaries()) { + if (options_->split_tokens_on_selection_boundaries()) { internal::SplitTokensOnSelectionBoundaries(input_span, tokens); } - if (options_.only_use_line_with_click()) { - StripTokensFromOtherLines(context, input_span, tokens); + if (only_use_line_with_click) { + StripTokensFromOtherLines(context_unicode, input_span, tokens); } int local_click_pos; @@ -693,6 +726,11 @@ void FeatureProcessor::TokenizeAndFindClick(const std::string& context, click_pos = &local_click_pos; } *click_pos = FindCenterToken(input_span, *tokens); + if (*click_pos == kInvalidIndex) { + // If the default click method failed, let's try to do sub-token matching + // before we fail. 
+ *click_pos = internal::CenterTokenFromClick(input_span, *tokens); + } } namespace internal { @@ -733,126 +771,104 @@ void StripOrPadTokens(TokenSpan relative_click_span, int context_size, } // namespace internal -bool FeatureProcessor::ExtractFeatures( - const std::string& context, CodepointSpan input_span, - TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn, - int feature_vector_size, std::vector<Token>* tokens, int* click_pos, - std::unique_ptr<CachedFeatures>* cached_features) const { - TokenizeAndFindClick(context, input_span, tokens, click_pos); - - if (input_span.first != kInvalidIndex && input_span.second != kInvalidIndex) { - // If the default click method failed, let's try to do sub-token matching - // before we fail. - if (*click_pos == kInvalidIndex) { - *click_pos = internal::CenterTokenFromClick(input_span, *tokens); - if (*click_pos == kInvalidIndex) { - return false; - } - } - } else { - // If input_span is unspecified, click the first token and extract features - // from all tokens. 
- *click_pos = 0; - relative_click_span = {0, tokens->size()}; - } - - internal::StripOrPadTokens(relative_click_span, options_.context_size(), - tokens, click_pos); - - if (options_.min_supported_codepoint_ratio() > 0) { +bool FeatureProcessor::HasEnoughSupportedCodepoints( + const std::vector<Token>& tokens, TokenSpan token_span) const { + if (options_->min_supported_codepoint_ratio() > 0) { const float supported_codepoint_ratio = - SupportedCodepointsRatio(*click_pos, *tokens); - if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) { + SupportedCodepointsRatio(token_span, tokens); + if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) { TC_VLOG(1) << "Not enough supported codepoints in the context: " << supported_codepoint_ratio; return false; } } + return true; +} - std::vector<std::vector<int>> sparse_features(tokens->size()); - std::vector<std::vector<float>> dense_features(tokens->size()); - for (int i = 0; i < tokens->size(); ++i) { - const Token& token = (*tokens)[i]; - if (!feature_extractor_.Extract(token, token.IsContainedInSpan(input_span), - &(sparse_features[i]), - &(dense_features[i]))) { - TC_LOG(ERROR) << "Could not extract token's features: " << token; +bool FeatureProcessor::ExtractFeatures( + const std::vector<Token>& tokens, TokenSpan token_span, + CodepointSpan selection_span_for_feature, + const EmbeddingExecutor* embedding_executor, + EmbeddingCache* embedding_cache, int feature_vector_size, + std::unique_ptr<CachedFeatures>* cached_features) const { + std::unique_ptr<std::vector<float>> features(new std::vector<float>()); + features->reserve(feature_vector_size * TokenSpanSize(token_span)); + for (int i = token_span.first; i < token_span.second; ++i) { + if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature, + embedding_executor, embedding_cache, + features.get())) { + TC_LOG(ERROR) << "Could not get token features."; return false; } } - cached_features->reset(new CachedFeatures( 
- *tokens, options_.context_size(), sparse_features, dense_features, - feature_vector_fn, feature_vector_size)); - - if (*cached_features == nullptr) { + std::unique_ptr<std::vector<float>> padding_features( + new std::vector<float>()); + padding_features->reserve(feature_vector_size); + if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature, + embedding_executor, embedding_cache, + padding_features.get())) { + TC_LOG(ERROR) << "Count not get padding token features."; return false; } - if (options_.feature_version() == 0) { - (*cached_features) - ->SetV0FeatureMode(feature_vector_size - - feature_extractor_.DenseFeaturesCount()); + *cached_features = CachedFeatures::Create(token_span, std::move(features), + std::move(padding_features), + options_, feature_vector_size); + if (!*cached_features) { + TC_LOG(ERROR) << "Cound not create cached features."; + return false; } return true; } -bool FeatureProcessor::ICUTokenize(const std::string& context, +bool FeatureProcessor::ICUTokenize(const UnicodeText& context_unicode, std::vector<Token>* result) const { -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - icu::ErrorCode status; - icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context); - std::unique_ptr<icu::BreakIterator> break_iterator( - icu::BreakIterator::createWordInstance(icu::Locale("en"), status)); - if (!status.isSuccess()) { - TC_LOG(ERROR) << "Break iterator did not initialize properly: " - << status.errorName(); + std::unique_ptr<UniLib::BreakIterator> break_iterator = + unilib_->CreateBreakIterator(context_unicode); + if (!break_iterator) { return false; } - - break_iterator->setText(unicode_text); - - size_t last_break_index = 0; - size_t break_index = 0; - size_t last_unicode_index = 0; - size_t unicode_index = 0; - while ((break_index = break_iterator->next()) != icu::BreakIterator::DONE) { - icu::UnicodeString token(unicode_text, last_break_index, - break_index - last_break_index); - int token_length = token.countChar32(); 
+ int last_break_index = 0; + int break_index = 0; + int last_unicode_index = 0; + int unicode_index = 0; + auto token_begin_it = context_unicode.begin(); + while ((break_index = break_iterator->Next()) != + UniLib::BreakIterator::kDone) { + const int token_length = break_index - last_break_index; unicode_index = last_unicode_index + token_length; - std::string token_utf8; - token.toUTF8String(token_utf8); + auto token_end_it = token_begin_it; + std::advance(token_end_it, token_length); + // Determine if the whole token is whitespace. bool is_whitespace = true; - for (int i = 0; i < token.length(); i++) { - if (!u_isWhitespace(token.char32At(i))) { + for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) { + if (!unilib_->IsWhitespace(*char_it)) { is_whitespace = false; + break; } } - if (!is_whitespace || options_.icu_preserve_whitespace_tokens()) { - result->push_back(Token(token_utf8, last_unicode_index, unicode_index)); + const std::string token = + context_unicode.UTF8Substring(token_begin_it, token_end_it); + + if (!is_whitespace || options_->icu_preserve_whitespace_tokens()) { + result->push_back(Token(token, last_unicode_index, unicode_index)); } last_break_index = break_index; last_unicode_index = unicode_index; + token_begin_it = token_end_it; } return true; -#else - TC_LOG(WARNING) << "Can't tokenize, ICU not supported"; - return false; -#endif } -void FeatureProcessor::InternalRetokenize(const std::string& context, +void FeatureProcessor::InternalRetokenize(const UnicodeText& unicode_text, std::vector<Token>* tokens) const { - const UnicodeText unicode_text = - UTF8ToUnicodeText(context, /*do_copy=*/false); - std::vector<Token> result; CodepointSpan span(-1, -1); for (Token& token : *tokens) { @@ -914,4 +930,69 @@ void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text, } } -} // namespace libtextclassifier +bool FeatureProcessor::AppendTokenFeaturesWithCache( + const Token& token, CodepointSpan 
selection_span_for_feature, + const EmbeddingExecutor* embedding_executor, + EmbeddingCache* embedding_cache, + std::vector<float>* output_features) const { + // Look for the embedded features for the token in the cache, if there is one. + if (embedding_cache) { + const auto it = embedding_cache->find({token.start, token.end}); + if (it != embedding_cache->end()) { + // The embedded features were found in the cache, extract only the dense + // features. + std::vector<float> dense_features; + if (!feature_extractor_.Extract( + token, token.IsContainedInSpan(selection_span_for_feature), + /*sparse_features=*/nullptr, &dense_features)) { + TC_LOG(ERROR) << "Could not extract token's dense features."; + return false; + } + + // Append both embedded and dense features to the output and return. + output_features->insert(output_features->end(), it->second.begin(), + it->second.end()); + output_features->insert(output_features->end(), dense_features.begin(), + dense_features.end()); + return true; + } + } + + // Extract the sparse and dense features. + std::vector<int> sparse_features; + std::vector<float> dense_features; + if (!feature_extractor_.Extract( + token, token.IsContainedInSpan(selection_span_for_feature), + &sparse_features, &dense_features)) { + TC_LOG(ERROR) << "Could not extract token's features."; + return false; + } + + // Embed the sparse features, appending them directly to the output. 
+ const int embedding_size = GetOptions()->embedding_size(); + output_features->resize(output_features->size() + embedding_size); + float* output_features_end = + output_features->data() + output_features->size(); + if (!embedding_executor->AddEmbedding( + TensorView<int>(sparse_features.data(), + {static_cast<int>(sparse_features.size())}), + /*dest=*/output_features_end - embedding_size, + /*dest_size=*/embedding_size)) { + TC_LOG(ERROR) << "Cound not embed token's sparse features."; + return false; + } + + // If there is a cache, the embedded features for the token were not in it, + // so insert them. + if (embedding_cache) { + (*embedding_cache)[{token.start, token.end}] = std::vector<float>( + output_features_end - embedding_size, output_features_end); + } + + // Append the dense features to the output. + output_features->insert(output_features->end(), dense_features.begin(), + dense_features.end()); + return true; +} + +} // namespace libtextclassifier2 diff --git a/smartselect/feature-processor.h b/feature-processor.h index ef9a3df..98d3449 100644 --- a/smartselect/feature-processor.h +++ b/feature-processor.h @@ -16,42 +16,33 @@ // Feature processing for FFModel (feed-forward SmartSelection model). 
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ +#ifndef LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ +#define LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ +#include <map> #include <memory> #include <set> #include <string> #include <vector> -#include "smartselect/cached-features.h" -#include "smartselect/text-classification-model.pb.h" -#include "smartselect/token-feature-extractor.h" -#include "smartselect/tokenizer.h" -#include "smartselect/types.h" +#include "cached-features.h" +#include "model_generated.h" +#include "token-feature-extractor.h" +#include "tokenizer.h" +#include "types.h" +#include "util/base/integral_types.h" #include "util/base/logging.h" #include "util/utf8/unicodetext.h" +#include "util/utf8/unilib.h" -namespace libtextclassifier { +namespace libtextclassifier2 { constexpr int kInvalidLabel = -1; -// Maps a vector of sparse features and a vector of dense features to a vector -// of features that combines both. -// The output is written to the memory location pointed to by the last float* -// argument. -// Returns true on success false on failure. -using FeatureVectorFn = std::function<bool(const std::vector<int>&, - const std::vector<float>&, float*)>; - namespace internal { -// Parses the serialized protocol buffer. -FeatureProcessorOptions ParseSerializedOptions( - const std::string& serialized_options); - TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( - const FeatureProcessorOptions& options); + const FeatureProcessorOptions* options); // Splits tokens that contain the selection boundary inside them. // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" @@ -73,6 +64,11 @@ int CenterTokenFromMiddleOfSelection( void StripOrPadTokens(TokenSpan relative_click_span, int context_size, std::vector<Token>* tokens, int* click_pos); +// If unilib is not nullptr, just returns unilib. 
Otherwise, if unilib is +// nullptr, will create UniLib, assign ownership to owned_unilib, and return it. +const UniLib* MaybeCreateUnilib(const UniLib* unilib, + std::unique_ptr<UniLib>* owned_unilib); + } // namespace internal // Converts a codepoint span to a token span in the given list of tokens. @@ -90,29 +86,48 @@ CodepointSpan TokenSpanToCodepointSpan( // Takes care of preparing features for the span prediction model. class FeatureProcessor { public: - explicit FeatureProcessor(const FeatureProcessorOptions& options) - : feature_extractor_( - internal::BuildTokenFeatureExtractorOptions(options)), + // A cache mapping codepoint spans to embedded tokens features. An instance + // can be provided to multiple calls to ExtractFeatures() operating on the + // same context (the same codepoint spans corresponding to the same tokens), + // as an optimization. Note that the tokenizations do not have to be + // identical. + typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache; + + // If unilib is nullptr, will create and own an instance of a UniLib, + // otherwise will use what's passed in. + explicit FeatureProcessor(const FeatureProcessorOptions* options, + const UniLib* unilib = nullptr) + : owned_unilib_(nullptr), + unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)), + feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options), + *unilib_), options_(options), - tokenizer_({options.tokenization_codepoint_config().begin(), - options.tokenization_codepoint_config().end()}) { + tokenizer_( + options->tokenization_codepoint_config() != nullptr + ? 
Tokenizer({options->tokenization_codepoint_config()->begin(), + options->tokenization_codepoint_config()->end()}, + options->tokenize_on_script_change()) + : Tokenizer({}, /*split_on_script_change=*/false)) { MakeLabelMaps(); - PrepareCodepointRanges({options.supported_codepoint_ranges().begin(), - options.supported_codepoint_ranges().end()}, - &supported_codepoint_ranges_); - PrepareCodepointRanges( - {options.internal_tokenizer_codepoint_ranges().begin(), - options.internal_tokenizer_codepoint_ranges().end()}, - &internal_tokenizer_codepoint_ranges_); + if (options->supported_codepoint_ranges() != nullptr) { + PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(), + options->supported_codepoint_ranges()->end()}, + &supported_codepoint_ranges_); + } + if (options->internal_tokenizer_codepoint_ranges() != nullptr) { + PrepareCodepointRanges( + {options->internal_tokenizer_codepoint_ranges()->begin(), + options->internal_tokenizer_codepoint_ranges()->end()}, + &internal_tokenizer_codepoint_ranges_); + } PrepareIgnoredSpanBoundaryCodepoints(); } - explicit FeatureProcessor(const std::string& serialized_options) - : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) { - } - // Tokenizes the input string using the selected tokenization method. - std::vector<Token> Tokenize(const std::string& utf8_text) const; + std::vector<Token> Tokenize(const std::string& text) const; + + // Same as above but takes UnicodeText. + std::vector<Token> Tokenize(const UnicodeText& text_unicode) const; // Converts a label into a token span. bool LabelToTokenSpan(int label, TokenSpan* token_span) const; @@ -129,22 +144,32 @@ class FeatureProcessor { // Gets the name of the default collection. std::string GetDefaultCollection() const; - const FeatureProcessorOptions& GetOptions() const { return options_; } + const FeatureProcessorOptions* GetOptions() const { return options_; } + + // Retokenizes the context and input span, and finds the click position. 
+ // Depending on the options, might modify tokens (split them or remove them). + void RetokenizeAndFindClick(const std::string& context, + CodepointSpan input_span, + bool only_use_line_with_click, + std::vector<Token>* tokens, int* click_pos) const; - // Tokenizes the context and input span, and finds the click position. - void TokenizeAndFindClick(const std::string& context, - CodepointSpan input_span, - std::vector<Token>* tokens, int* click_pos) const; + // Same as above but takes UnicodeText. + void RetokenizeAndFindClick(const UnicodeText& context_unicode, + CodepointSpan input_span, + bool only_use_line_with_click, + std::vector<Token>* tokens, int* click_pos) const; + + // Returns true if the token span has enough supported codepoints (as defined + // in the model config) or not and model should not run. + bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens, + TokenSpan token_span) const; // Extracts features as a CachedFeatures object that can be used for repeated // inference over token spans in the given context. - // When input_span == {kInvalidIndex, kInvalidIndex} then, relative_click_span - // is ignored, and all tokens extracted from context will be considered. - bool ExtractFeatures(const std::string& context, CodepointSpan input_span, - TokenSpan relative_click_span, - const FeatureVectorFn& feature_vector_fn, - int feature_vector_size, std::vector<Token>* tokens, - int* click_pos, + bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span, + CodepointSpan selection_span_for_feature, + const EmbeddingExecutor* embedding_executor, + EmbeddingCache* embedding_cache, int feature_vector_size, std::unique_ptr<CachedFeatures>* cached_features) const; // Fills selection_label_spans with CodepointSpans that correspond to the @@ -158,7 +183,9 @@ class FeatureProcessor { return feature_extractor_.DenseFeaturesCount(); } - // Splits context to several segments according to configuration. 
+ int EmbeddingSize() const { return options_->embedding_size(); } + + // Splits context to several segments. std::vector<UnicodeTextRange> SplitContext( const UnicodeText& context_unicode) const; @@ -168,6 +195,10 @@ class FeatureProcessor { CodepointSpan StripBoundaryCodepoints(const std::string& context, CodepointSpan span) const; + // Same as above but takes UnicodeText. + CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode, + CodepointSpan span) const; + protected: // Represents a codepoint range [start, end). struct CodepointRange { @@ -191,7 +222,7 @@ class FeatureProcessor { // Spannable tokens are those tokens of context, which the model predicts // selection spans over (i.e., there is 1:1 correspondence between the output // classes of the model and each of the spannable tokens). - int GetNumContextTokens() const { return options_.context_size() * 2 + 1; } + int GetNumContextTokens() const { return options_->context_size() * 2 + 1; } // Converts a label into a span of codepoint indices corresponding to it // given output_tokens. @@ -206,13 +237,13 @@ class FeatureProcessor { int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const; void PrepareCodepointRanges( - const std::vector<FeatureProcessorOptions::CodepointRange>& + const std::vector<const FeatureProcessorOptions_::CodepointRange*>& codepoint_ranges, std::vector<CodepointRange>* prepared_codepoint_ranges); // Returns the ratio of supported codepoints to total number of codepoints in - // the input context around given click position. - float SupportedCodepointsRatio(int click_pos, + // the given token span. + float SupportedCodepointsRatio(const TokenSpan& token_span, const std::vector<Token>& tokens) const; // Returns true if given codepoint is covered by the given sorted vector of @@ -238,12 +269,12 @@ class FeatureProcessor { const std::vector<Token>& tokens) const; // Tokenizes the input text using ICU tokenizer. 
- bool ICUTokenize(const std::string& context, + bool ICUTokenize(const UnicodeText& context_unicode, std::vector<Token>* result) const; // Takes the result of ICU tokenization and retokenizes stretches of tokens // made of a specific subset of characters using the internal tokenizer. - void InternalRetokenize(const std::string& context, + void InternalRetokenize(const UnicodeText& unicode_text, std::vector<Token>* tokens) const; // Tokenizes a substring of the unicode string, appending the resulting tokens @@ -257,6 +288,25 @@ class FeatureProcessor { void StripTokensFromOtherLines(const std::string& context, CodepointSpan span, std::vector<Token>* tokens) const; + // Same as above but takes UnicodeText. + void StripTokensFromOtherLines(const UnicodeText& context_unicode, + CodepointSpan span, + std::vector<Token>* tokens) const; + + // Extracts the features of a token and appends them to the output vector. + // Uses the embedding cache to to avoid re-extracting the re-embedding the + // sparse features for the same token. + bool AppendTokenFeaturesWithCache(const Token& token, + CodepointSpan selection_span_for_feature, + const EmbeddingExecutor* embedding_executor, + EmbeddingCache* embedding_cache, + std::vector<float>* output_features) const; + + private: + std::unique_ptr<UniLib> owned_unilib_; + const UniLib* unilib_; + + protected: const TokenFeatureExtractor feature_extractor_; // Codepoint ranges that define what codepoints are supported by the model. @@ -274,7 +324,7 @@ class FeatureProcessor { // predicted spans. std::set<int32> ignored_span_boundary_codepoints_; - const FeatureProcessorOptions options_; + const FeatureProcessorOptions* const options_; // Mapping between token selection spans and labels ids. 
std::map<TokenSpan, int> selection_to_label_; @@ -286,6 +336,6 @@ class FeatureProcessor { Tokenizer tokenizer_; }; -} // namespace libtextclassifier +} // namespace libtextclassifier2 -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ +#endif // LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ diff --git a/smartselect/feature-processor_test.cc b/feature-processor_test.cc index 9bee67a..58b3033 100644 --- a/smartselect/feature-processor_test.cc +++ b/feature-processor_test.cc @@ -14,16 +14,40 @@ * limitations under the License. */ -#include "smartselect/feature-processor.h" +#include "feature-processor.h" + +#include "model-executor.h" +#include "tensor-view.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace { using testing::ElementsAreArray; using testing::FloatEq; +using testing::Matcher; + +flatbuffers::DetachedBuffer PackFeatureProcessorOptions( + const FeatureProcessorOptionsT& options) { + flatbuffers::FlatBufferBuilder builder; + builder.Finish(CreateFeatureProcessorOptions(builder, &options)); + return builder.Release(); +} + +template <typename T> +std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) { + return std::vector<T>(vector.begin() + start, vector.begin() + end); +} + +Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) { + std::vector<Matcher<float>> matchers; + for (const float value : values) { + matchers.push_back(FloatEq(value)); + } + return ElementsAreArray(matchers); +} class TestingFeatureProcessor : public FeatureProcessor { public: @@ -37,6 +61,24 @@ class TestingFeatureProcessor : public FeatureProcessor { using FeatureProcessor::SupportedCodepointsRatio; }; +// EmbeddingExecutor that always returns features based on +class FakeEmbeddingExecutor : public EmbeddingExecutor { + public: + bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, + int dest_size) const override { + TC_CHECK_GE(dest_size, 
4); + EXPECT_EQ(sparse_features.size(), 1); + dest[0] = sparse_features.data()[0]; + dest[1] = sparse_features.data()[0]; + dest[2] = -sparse_features.data()[0]; + dest[3] = -sparse_features.data()[0]; + return true; + } + + private: + std::vector<float> storage_; +}; + TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) { std::vector<Token> tokens{Token("Hělló", 0, 5), Token("fěěbař@google.com", 6, 23), @@ -119,9 +161,13 @@ TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) { } TEST(FeatureProcessorTest, KeepLineWithClickFirst) { - FeatureProcessorOptions options; - options.set_only_use_line_with_click(true); - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {0, 5}; @@ -141,9 +187,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickFirst) { } TEST(FeatureProcessorTest, KeepLineWithClickSecond) { - FeatureProcessorOptions options; - options.set_only_use_line_with_click(true); - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {18, 22}; @@ -163,9 +213,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickSecond) { } TEST(FeatureProcessorTest, KeepLineWithClickThird) { - FeatureProcessorOptions options; - 
options.set_only_use_line_with_click(true); - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {24, 33}; @@ -185,9 +239,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickThird) { } TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) { - FeatureProcessorOptions options; - options.set_only_use_line_with_click(true); - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině"; const CodepointSpan span = {18, 22}; @@ -207,9 +265,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) { } TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) { - FeatureProcessorOptions options; - options.set_only_use_line_with_click(true); - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {5, 23}; @@ -231,18 +293,23 @@ TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) { } 
TEST(FeatureProcessorTest, SpanToLabel) { - FeatureProcessorOptions options; - options.set_context_size(1); - options.set_max_selection_span(1); - options.set_snap_label_span_boundaries_to_containing_tokens(false); - - TokenizationCodepointRange* config = - options.add_tokenization_codepoint_config(); - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.context_size = 1; + options.max_selection_span = 1; + options.snap_label_span_boundaries_to_containing_tokens = false; + + options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + auto& config = options.tokenization_codepoint_config.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); ASSERT_EQ(3, tokens.size()); int label; @@ -256,8 +323,12 @@ TEST(FeatureProcessorTest, SpanToLabel) { EXPECT_EQ(0, token_span.second); // Reconfigure with snapping enabled. 
- options.set_snap_label_span_boundaries_to_containing_tokens(true); - TestingFeatureProcessor feature_processor2(options); + options.snap_label_span_boundaries_to_containing_tokens = true; + flatbuffers::DetachedBuffer options2_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor2( + flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), + &unilib); int label2; ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); EXPECT_EQ(label, label2); @@ -273,9 +344,13 @@ TEST(FeatureProcessorTest, SpanToLabel) { EXPECT_EQ(kInvalidLabel, label2); // Multiple tokens. - options.set_context_size(2); - options.set_max_selection_span(2); - TestingFeatureProcessor feature_processor3(options); + options.context_size = 2; + options.max_selection_span = 2; + flatbuffers::DetachedBuffer options3_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor3( + flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), + &unilib); tokens = feature_processor3.Tokenize("zero, one, two, three, four"); ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); EXPECT_NE(kInvalidLabel, label2); @@ -293,18 +368,23 @@ TEST(FeatureProcessorTest, SpanToLabel) { } TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { - FeatureProcessorOptions options; - options.set_context_size(1); - options.set_max_selection_span(1); - options.set_snap_label_span_boundaries_to_containing_tokens(false); - - TokenizationCodepointRange* config = - options.add_tokenization_codepoint_config(); - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.context_size = 1; + options.max_selection_span = 1; + options.snap_label_span_boundaries_to_containing_tokens = false; + + 
options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + auto& config = options.tokenization_codepoint_config.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); ASSERT_EQ(3, tokens.size()); int label; @@ -318,8 +398,12 @@ TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { EXPECT_EQ(0, token_span.second); // Reconfigure with snapping enabled. - options.set_snap_label_span_boundaries_to_containing_tokens(true); - TestingFeatureProcessor feature_processor2(options); + options.snap_label_span_boundaries_to_containing_tokens = true; + flatbuffers::DetachedBuffer options2_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor2( + flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), + &unilib); int label2; ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); EXPECT_EQ(label, label2); @@ -335,9 +419,13 @@ TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { EXPECT_EQ(kInvalidLabel, label2); // Multiple tokens. 
- options.set_context_size(2); - options.set_max_selection_span(2); - TestingFeatureProcessor feature_processor3(options); + options.context_size = 2; + options.max_selection_span = 2; + flatbuffers::DetachedBuffer options3_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor3( + flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), + &unilib); tokens = feature_processor3.Tokenize("zero, one, two, three, four"); ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); EXPECT_NE(kInvalidLabel, label2); @@ -420,39 +508,66 @@ TEST(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) { } TEST(FeatureProcessorTest, SupportedCodepointsRatio) { - FeatureProcessorOptions options; - options.set_context_size(2); - options.set_max_selection_span(2); - options.set_snap_label_span_boundaries_to_containing_tokens(false); - - TokenizationCodepointRange* config = - options.add_tokenization_codepoint_config(); - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - - FeatureProcessorOptions::CodepointRange* range; - range = options.add_supported_codepoint_ranges(); - range->set_start(0); - range->set_end(128); - - range = options.add_supported_codepoint_ranges(); - range->set_start(10000); - range->set_end(10001); - - range = options.add_supported_codepoint_ranges(); - range->set_start(20000); - range->set_end(30000); - - TestingFeatureProcessor feature_processor(options); + FeatureProcessorOptionsT options; + options.context_size = 2; + options.max_selection_span = 2; + options.snap_label_span_boundaries_to_containing_tokens = false; + options.feature_version = 2; + options.embedding_size = 4; + options.bounds_sensitive_features.reset( + new FeatureProcessorOptions_::BoundsSensitiveFeaturesT()); + options.bounds_sensitive_features->enabled = true; + options.bounds_sensitive_features->num_tokens_before = 5; + 
options.bounds_sensitive_features->num_tokens_inside_left = 3; + options.bounds_sensitive_features->num_tokens_inside_right = 3; + options.bounds_sensitive_features->num_tokens_after = 5; + options.bounds_sensitive_features->include_inside_bag = true; + options.bounds_sensitive_features->include_inside_length = true; + + options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + auto& config = options.tokenization_codepoint_config.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + { + options.supported_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.supported_codepoint_ranges.back(); + range->start = 0; + range->end = 128; + } + + { + options.supported_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.supported_codepoint_ranges.back(); + range->start = 10000; + range->end = 10001; + } + + { + options.supported_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.supported_codepoint_ranges.back(); + range->start = 20000; + range->end = 30000; + } + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + CREATE_UNILIB_FOR_TESTING; + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); EXPECT_THAT(feature_processor.SupportedCodepointsRatio( - 1, feature_processor.Tokenize("aaa bbb ccc")), + {0, 3}, feature_processor.Tokenize("aaa bbb ccc")), FloatEq(1.0)); EXPECT_THAT(feature_processor.SupportedCodepointsRatio( - 1, feature_processor.Tokenize("aaa bbb ěěě")), + {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")), FloatEq(2.0 / 3)); EXPECT_THAT(feature_processor.SupportedCodepointsRatio( - 1, feature_processor.Tokenize("ěěě řřř ěěě")), + {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")), 
FloatEq(0.0)); EXPECT_FALSE(feature_processor.IsCodepointInRanges( -1, feature_processor.supported_codepoint_ranges_)); @@ -473,32 +588,142 @@ TEST(FeatureProcessorTest, SupportedCodepointsRatio) { EXPECT_TRUE(feature_processor.IsCodepointInRanges( 25000, feature_processor.supported_codepoint_ranges_)); - std::vector<Token> tokens; - int click_pos; - std::vector<float> extra_features; + const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7), + Token("eee", 8, 11)}; + + options.min_supported_codepoint_ratio = 0.0; + flatbuffers::DetachedBuffer options2_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor2( + flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), + &unilib); + EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints( + tokens, /*token_span=*/{0, 3})); + + options.min_supported_codepoint_ratio = 0.2; + flatbuffers::DetachedBuffer options3_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor3( + flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), + &unilib); + EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints( + tokens, /*token_span=*/{0, 3})); + + options.min_supported_codepoint_ratio = 0.5; + flatbuffers::DetachedBuffer options4_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor4( + flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()), + &unilib); + EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints( + tokens, /*token_span=*/{0, 3})); +} + +TEST(FeatureProcessorTest, InSpanFeature) { + FeatureProcessorOptionsT options; + options.context_size = 2; + options.max_selection_span = 2; + options.snap_label_span_boundaries_to_containing_tokens = false; + options.feature_version = 2; + options.embedding_size = 4; + options.extract_selection_mask_feature = true; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + 
CREATE_UNILIB_FOR_TESTING; + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); + + std::unique_ptr<CachedFeatures> cached_features; + + FakeEmbeddingExecutor embedding_executor; + + const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7), + Token("ccc", 8, 11), Token("ddd", 12, 15)}; + + EXPECT_TRUE(feature_processor.ExtractFeatures( + tokens, /*token_span=*/{0, 4}, + /*selection_span_for_feature=*/{4, 11}, &embedding_executor, + /*embedding_cache=*/nullptr, /*feature_vector_size=*/5, + &cached_features)); + std::vector<float> features; + cached_features->AppendClickContextFeaturesForClick(1, &features); + ASSERT_EQ(features.size(), 25); + EXPECT_THAT(features[4], FloatEq(0.0)); + EXPECT_THAT(features[9], FloatEq(0.0)); + EXPECT_THAT(features[14], FloatEq(1.0)); + EXPECT_THAT(features[19], FloatEq(1.0)); + EXPECT_THAT(features[24], FloatEq(0.0)); +} + +TEST(FeatureProcessorTest, EmbeddingCache) { + FeatureProcessorOptionsT options; + options.context_size = 2; + options.max_selection_span = 2; + options.snap_label_span_boundaries_to_containing_tokens = false; + options.feature_version = 2; + options.embedding_size = 4; + options.bounds_sensitive_features.reset( + new FeatureProcessorOptions_::BoundsSensitiveFeaturesT()); + options.bounds_sensitive_features->enabled = true; + options.bounds_sensitive_features->num_tokens_before = 3; + options.bounds_sensitive_features->num_tokens_inside_left = 2; + options.bounds_sensitive_features->num_tokens_inside_right = 2; + options.bounds_sensitive_features->num_tokens_after = 3; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + CREATE_UNILIB_FOR_TESTING; + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); + std::unique_ptr<CachedFeatures> cached_features; - auto feature_fn = [](const std::vector<int>& sparse_features, - const 
std::vector<float>& dense_features, - float* embedding) { return true; }; - - options.set_min_supported_codepoint_ratio(0.0); - TestingFeatureProcessor feature_processor2(options); - EXPECT_TRUE(feature_processor2.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0}, - feature_fn, 2, &tokens, - &click_pos, &cached_features)); - - options.set_min_supported_codepoint_ratio(0.2); - TestingFeatureProcessor feature_processor3(options); - EXPECT_TRUE(feature_processor3.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0}, - feature_fn, 2, &tokens, - &click_pos, &cached_features)); - - options.set_min_supported_codepoint_ratio(0.5); - TestingFeatureProcessor feature_processor4(options); - EXPECT_FALSE(feature_processor4.ExtractFeatures( - "ěěě řřř eee", {4, 7}, {0, 0}, feature_fn, 2, &tokens, &click_pos, + FakeEmbeddingExecutor embedding_executor; + + const std::vector<Token> tokens = { + Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11), + Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)}; + + // We pre-populate the cache with dummy embeddings, to make sure they are + // used when populating the features vector. + const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0}; + const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0}; + const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0}; + FeatureProcessor::EmbeddingCache embedding_cache = { + {{kInvalidIndex, kInvalidIndex}, cached_padding_features}, + {{4, 7}, cached_features1}, + {{12, 15}, cached_features2}, + }; + + EXPECT_TRUE(feature_processor.ExtractFeatures( + tokens, /*token_span=*/{0, 6}, + /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, + &embedding_executor, &embedding_cache, /*feature_vector_size=*/4, &cached_features)); + std::vector<float> features; + cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features); + ASSERT_EQ(features.size(), 40); + // Check that the dummy embeddings were used. 
+ EXPECT_THAT(Subvector(features, 0, 4), + ElementsAreFloat(cached_padding_features)); + EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1)); + EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2)); + EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2)); + EXPECT_THAT(Subvector(features, 36, 40), + ElementsAreFloat(cached_padding_features)); + // Check that the real embeddings were cached. + EXPECT_EQ(embedding_cache.size(), 7); + EXPECT_THAT(Subvector(features, 4, 8), + ElementsAreFloat(embedding_cache.at({0, 3}))); + EXPECT_THAT(Subvector(features, 12, 16), + ElementsAreFloat(embedding_cache.at({8, 11}))); + EXPECT_THAT(Subvector(features, 20, 24), + ElementsAreFloat(embedding_cache.at({8, 11}))); + EXPECT_THAT(Subvector(features, 28, 32), + ElementsAreFloat(embedding_cache.at({16, 19}))); + EXPECT_THAT(Subvector(features, 32, 36), + ElementsAreFloat(embedding_cache.at({20, 23}))); } TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) { @@ -613,12 +838,48 @@ TEST(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) { EXPECT_EQ(click_index, 5); } +TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) { + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + { + auto& config = options.tokenization_codepoint_config.back(); + config->start = 0; + config->end = 256; + config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE; + config->script_id = 1; + } + options.tokenize_on_script_change = false; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); + + EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"), + std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)})); + + options.tokenize_on_script_change = true; + 
flatbuffers::DetachedBuffer options_fb2 = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor2( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()), + &unilib); + + EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"), + std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7), + Token("웹사이트", 7, 11)})); +} + +#ifdef LIBTEXTCLASSIFIER_TEST_ICU TEST(FeatureProcessorTest, ICUTokenize) { - FeatureProcessorOptions options; - options.set_tokenization_type( - libtextclassifier::FeatureProcessorOptions::ICU); + FeatureProcessorOptionsT options; + options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU; - TestingFeatureProcessor feature_processor(options); + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data())); std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ"); ASSERT_EQ(tokens, // clang-format off @@ -629,14 +890,17 @@ TEST(FeatureProcessorTest, ICUTokenize) { Token("มิ", 17, 19)})); // clang-format on } +#endif +#ifdef LIBTEXTCLASSIFIER_TEST_ICU TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) { - FeatureProcessorOptions options; - options.set_tokenization_type( - libtextclassifier::FeatureProcessorOptions::ICU); - options.set_icu_preserve_whitespace_tokens(true); + FeatureProcessorOptionsT options; + options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU; + options.icu_preserve_whitespace_tokens = true; - TestingFeatureProcessor feature_processor(options); + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data())); std::vector<Token> tokens = feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ"); ASSERT_EQ(tokens, @@ -652,36 +916,55 @@ TEST(FeatureProcessorTest, 
ICUTokenizeWithWhitespaces) { Token("มิ", 21, 23)})); // clang-format on } +#endif +#ifdef LIBTEXTCLASSIFIER_TEST_ICU TEST(FeatureProcessorTest, MixedTokenize) { - FeatureProcessorOptions options; - options.set_tokenization_type( - libtextclassifier::FeatureProcessorOptions::MIXED); - - TokenizationCodepointRange* config = - options.add_tokenization_codepoint_config(); - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - - FeatureProcessorOptions::CodepointRange* range; - range = options.add_internal_tokenizer_codepoint_ranges(); - range->set_start(0); - range->set_end(128); - - range = options.add_internal_tokenizer_codepoint_ranges(); - range->set_start(128); - range->set_end(256); - - range = options.add_internal_tokenizer_codepoint_ranges(); - range->set_start(256); - range->set_end(384); - - range = options.add_internal_tokenizer_codepoint_ranges(); - range->set_start(384); - range->set_end(592); - - TestingFeatureProcessor feature_processor(options); + FeatureProcessorOptionsT options; + options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED; + + options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + auto& config = options.tokenization_codepoint_config.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + { + options.internal_tokenizer_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.internal_tokenizer_codepoint_ranges.back(); + range->start = 0; + range->end = 128; + } + + { + options.internal_tokenizer_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.internal_tokenizer_codepoint_ranges.back(); + range->start = 128; + range->end = 256; + } + + { + options.internal_tokenizer_codepoint_ranges.emplace_back( + new 
FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.internal_tokenizer_codepoint_ranges.back(); + range->start = 256; + range->end = 384; + } + + { + options.internal_tokenizer_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.internal_tokenizer_codepoint_ranges.back(); + range->start = 384; + range->end = 592; + } + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data())); std::vector<Token> tokens = feature_processor.Tokenize( "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/"); ASSERT_EQ(tokens, @@ -693,15 +976,20 @@ TEST(FeatureProcessorTest, MixedTokenize) { Token("http://www.google.com/", 31, 53)})); // clang-format on } +#endif TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) { - FeatureProcessorOptions options; - options.add_ignored_span_boundary_codepoints('.'); - options.add_ignored_span_boundary_codepoints(','); - options.add_ignored_span_boundary_codepoints('['); - options.add_ignored_span_boundary_codepoints(']'); - - TestingFeatureProcessor feature_processor(options); + CREATE_UNILIB_FOR_TESTING; + FeatureProcessorOptionsT options; + options.ignored_span_boundary_codepoints.push_back('.'); + options.ignored_span_boundary_codepoints.push_back(','); + options.ignored_span_boundary_codepoints.push_back('['); + options.ignored_span_boundary_codepoints.push_back(']'); + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); const std::string text1_utf8 = "ěščř"; const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false); @@ -834,4 +1122,4 @@ TEST(FeatureProcessorTest, CodepointSpanToTokenSpan) { } } // namespace -} // namespace libtextclassifier +} // namespace 
libtextclassifier2 @@ -1,7 +1,7 @@ -{ - # Export symbols that correspond to our JNIEXPORTed functions. +VERS_1.0 { + # Export JNI symbols. global: - Java_android_view_textclassifier_*; + Java_*; # Hide everything else. local: diff --git a/lang_id/custom-tokenizer.cc b/lang_id/custom-tokenizer.cc deleted file mode 100644 index 7e30cc7..0000000 --- a/lang_id/custom-tokenizer.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "lang_id/custom-tokenizer.h" - -#include <ctype.h> - -#include <string> - -#include "util/strings/utf8.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -namespace { -inline bool IsTokenSeparator(int num_bytes, const char *curr) { - if (num_bytes != 1) { - return false; - } - return !isalpha(*curr); -} -} // namespace - -const char *GetSafeEndOfString(const char *data, size_t size) { - const char *const hard_end = data + size; - const char *curr = data; - while (curr < hard_end) { - int num_bytes = GetNumBytesForUTF8Char(curr); - if (num_bytes == 0) { - break; - } - const char *new_curr = curr + num_bytes; - if (new_curr > hard_end) { - return curr; - } - curr = new_curr; - } - return curr; -} - -void TokenizeTextForLangId(const std::string &text, LightSentence *sentence) { - const char *const start = text.data(); - const char *curr = start; - const char *end = GetSafeEndOfString(start, text.size()); - - // Corner case: empty safe part of the text. - if (curr >= end) { - return; - } - - // Number of bytes for UTF8 character starting at *curr. Note: the loop below - // is guaranteed to terminate because in each iteration, we move curr by at - // least num_bytes, and num_bytes is guaranteed to be > 0. - int num_bytes = GetNumBytesForNonZeroUTF8Char(curr); - while (curr < end) { - // Jump over consecutive token separators. - while (IsTokenSeparator(num_bytes, curr)) { - curr += num_bytes; - if (curr >= end) { - return; - } - num_bytes = GetNumBytesForNonZeroUTF8Char(curr); - } - - // If control reaches this point, we are at beginning of a non-empty token. - std::string *word = sentence->add_word(); - - // Add special token-start character. - word->push_back('^'); - - // Add UTF8 characters to word, until we hit the end of the safe text or a - // token separator. 
- while (true) { - word->append(curr, num_bytes); - curr += num_bytes; - if (curr >= end) { - break; - } - num_bytes = GetNumBytesForNonZeroUTF8Char(curr); - if (IsTokenSeparator(num_bytes, curr)) { - curr += num_bytes; - num_bytes = GetNumBytesForNonZeroUTF8Char(curr); - break; - } - } - word->push_back('$'); - - // Note: we intentionally do not token.set_start()/end(), as those fields - // are not used by the langid model. - } -} - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/lang_id/custom-tokenizer.h b/lang_id/custom-tokenizer.h deleted file mode 100644 index c9c291c..0000000 --- a/lang_id/custom-tokenizer.h +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_CUSTOM_TOKENIZER_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_CUSTOM_TOKENIZER_H_ - -#include <cstddef> -#include <string> - -#include "lang_id/light-sentence.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Perform custom tokenization of text. Customized for the language -// identification project. Currently (Sep 15, 2016) we tokenize on space, -// newline, and tab, ignore all empty tokens, and (for each of the remaining -// tokens) prepend "^" (special token begin marker) and append "$" (special -// token end marker). -// -// Tokens are stored into the words of the LightSentence *sentence. 
-void TokenizeTextForLangId(const std::string &text, LightSentence *sentence); - -// Returns a pointer "end" inside [data, data + size) such that the prefix from -// [data, end) is the largest one that does not contain '\0' and offers the -// following guarantee: if one starts with -// -// curr = text.data() -// -// and keeps executing -// -// curr += utils::GetNumBytesForNonZeroUTF8Char(curr) -// -// one would eventually reach curr == end (the pointer returned by this -// function) without accessing data outside the std::string. This guards -// against scenarios like a broken UTF-8 string which has only e.g., the first 2 -// bytes from a 3-byte UTF8 sequence. -const char *GetSafeEndOfString(const char *data, size_t size); - -static inline const char *GetSafeEndOfString(const std::string &text) { - return GetSafeEndOfString(text.data(), text.size()); -} - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_CUSTOM_TOKENIZER_H_ diff --git a/lang_id/lang-id-brain-interface.h b/lang_id/lang-id-brain-interface.h deleted file mode 100644 index ce79497..0000000 --- a/lang_id/lang-id-brain-interface.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ - -#include <string> -#include <vector> - -#include "common/embedding-feature-extractor.h" -#include "common/feature-extractor.h" -#include "common/task-context.h" -#include "common/workspace.h" -#include "lang_id/light-sentence-features.h" -#include "lang_id/light-sentence.h" -#include "util/base/macros.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Specialization of EmbeddingFeatureExtractor that extracts from LightSentence. -class LangIdEmbeddingFeatureExtractor - : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> { - public: - LangIdEmbeddingFeatureExtractor() {} - const std::string ArgPrefix() const override { return "language_identifier"; } - - TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor); -}; - -// Handles sentence -> numeric_features and numeric_prediction -> language -// conversions. -class LangIdBrainInterface { - public: - LangIdBrainInterface() {} - - // Initializes resources and parameters. - bool Init(TaskContext *context) { - if (!feature_extractor_.Init(context)) { - return false; - } - feature_extractor_.RequestWorkspaces(&workspace_registry_); - return true; - } - - // Extract features from sentence. On return, FeatureVector features[i] - // contains the features for the embedding space #i. - void GetFeatures(LightSentence *sentence, - std::vector<FeatureVector> *features) const { - WorkspaceSet workspace; - workspace.Reset(workspace_registry_); - feature_extractor_.Preprocess(&workspace, sentence); - return feature_extractor_.ExtractFeatures(workspace, *sentence, features); - } - - int NumEmbeddings() const { - return feature_extractor_.NumEmbeddings(); - } - - private: - // Typed feature extractor for embeddings. - LangIdEmbeddingFeatureExtractor feature_extractor_; - - // The registry of shared workspaces in the feature extractor. 
- WorkspaceRegistry workspace_registry_; - - TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface); -}; - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc deleted file mode 100644 index 8383d33..0000000 --- a/lang_id/lang-id.cc +++ /dev/null @@ -1,402 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "lang_id/lang-id.h" - -#include <stdio.h> - -#include <algorithm> -#include <limits> -#include <memory> -#include <string> -#include <vector> - -#include "common/algorithm.h" -#include "common/embedding-network-params-from-proto.h" -#include "common/embedding-network.pb.h" -#include "common/embedding-network.h" -#include "common/feature-extractor.h" -#include "common/file-utils.h" -#include "common/list-of-strings.pb.h" -#include "common/memory_image/in-memory-model-data.h" -#include "common/mmap.h" -#include "common/softmax.h" -#include "common/task-context.h" -#include "lang_id/custom-tokenizer.h" -#include "lang_id/lang-id-brain-interface.h" -#include "lang_id/language-identifier-features.h" -#include "lang_id/light-sentence-features.h" -#include "lang_id/light-sentence.h" -#include "lang_id/relevant-script-feature.h" -#include "util/base/logging.h" -#include "util/base/macros.h" - -using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory; - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -namespace { -// Default value for the probability threshold; see comments for -// LangId::SetProbabilityThreshold(). -static const float kDefaultProbabilityThreshold = 0.50; - -// Default value for min text size below which our model can't provide a -// meaningful prediction. -static const int kDefaultMinTextSizeInBytes = 20; - -// Initial value for the default language for LangId::FindLanguage(). The -// default language can be changed (for an individual LangId object) using -// LangId::SetDefaultLanguage(). -static const char kInitialDefaultLanguage[] = ""; - -// Returns total number of bytes of the words from sentence, without the ^ -// (start-of-word) and $ (end-of-word) markers. Note: "real text" means that -// this ignores whitespace and punctuation characters from the original text. 
-int GetRealTextSize(const LightSentence &sentence) { - int total = 0; - for (int i = 0; i < sentence.num_words(); ++i) { - TC_DCHECK(!sentence.word(i).empty()); - TC_DCHECK_EQ('^', sentence.word(i).front()); - TC_DCHECK_EQ('$', sentence.word(i).back()); - total += sentence.word(i).size() - 2; - } - return total; -} - -} // namespace - -// Class that performs all work behind LangId. -class LangIdImpl { - public: - explicit LangIdImpl(const std::string &filename) { - // Using mmap as a fast way to read the model bytes. - ScopedMmap scoped_mmap(filename); - MmapHandle mmap_handle = scoped_mmap.handle(); - if (!mmap_handle.ok()) { - TC_LOG(ERROR) << "Unable to read model bytes."; - return; - } - - Initialize(mmap_handle.to_stringpiece()); - } - - explicit LangIdImpl(int fd) { - // Using mmap as a fast way to read the model bytes. - ScopedMmap scoped_mmap(fd); - MmapHandle mmap_handle = scoped_mmap.handle(); - if (!mmap_handle.ok()) { - TC_LOG(ERROR) << "Unable to read model bytes."; - return; - } - - Initialize(mmap_handle.to_stringpiece()); - } - - LangIdImpl(const char *ptr, size_t length) { - Initialize(StringPiece(ptr, length)); - } - - void Initialize(StringPiece model_bytes) { - // Will set valid_ to true only on successful initialization. - valid_ = false; - - // Make sure all relevant features are registered: - ContinuousBagOfNgramsFunction::RegisterClass(); - RelevantScriptFeature::RegisterClass(); - - // NOTE(salcianu): code below relies on the fact that the current features - // do not rely on data from a TaskInput. Otherwise, one would have to use - // the more complex model registration mechanism, which requires more code. 
- InMemoryModelData model_data(model_bytes); - TaskContext context; - if (!model_data.GetTaskSpec(context.mutable_spec())) { - TC_LOG(ERROR) << "Unable to get model TaskSpec"; - return; - } - - if (!ParseNetworkParams(model_data, &context)) { - return; - } - if (!ParseListOfKnownLanguages(model_data, &context)) { - return; - } - - network_.reset(new EmbeddingNetwork(network_params_.get())); - if (!network_->is_valid()) { - return; - } - - probability_threshold_ = - context.Get("reliability_thresh", kDefaultProbabilityThreshold); - min_text_size_in_bytes_ = - context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes); - version_ = context.Get("version", 0); - - if (!lang_id_brain_interface_.Init(&context)) { - return; - } - valid_ = true; - } - - void SetProbabilityThreshold(float threshold) { - probability_threshold_ = threshold; - } - - void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; } - - std::string FindLanguage(const std::string &text) const { - std::vector<float> scores = ScoreLanguages(text); - if (scores.empty()) { - return default_language_; - } - - // Softmax label with max score. - int label = GetArgMax(scores); - float probability = scores[label]; - if (probability < probability_threshold_) { - return default_language_; - } - return GetLanguageForSoftmaxLabel(label); - } - - std::vector<std::pair<std::string, float>> FindLanguages( - const std::string &text) const { - std::vector<float> scores = ScoreLanguages(text); - - std::vector<std::pair<std::string, float>> result; - for (int i = 0; i < scores.size(); i++) { - result.push_back({GetLanguageForSoftmaxLabel(i), scores[i]}); - } - - // To avoid crashing clients that always expect at least one predicted - // language, we promised (see doc for this method) that the result always - // contains at least one element. - if (result.empty()) { - // We use a tiny probability, such that any client that uses a meaningful - // probability threshold ignores this prediction. 
We don't use 0.0f, to - // avoid crashing clients that normalize the probabilities we return here. - result.push_back({default_language_, 0.001f}); - } - return result; - } - - std::vector<float> ScoreLanguages(const std::string &text) const { - if (!is_valid()) { - return {}; - } - - // Create a Sentence storing the input text. - LightSentence sentence; - TokenizeTextForLangId(text, &sentence); - - if (GetRealTextSize(sentence) < min_text_size_in_bytes_) { - return {}; - } - - // TODO(salcianu): reuse vector<FeatureVector>. - std::vector<FeatureVector> features( - lang_id_brain_interface_.NumEmbeddings()); - lang_id_brain_interface_.GetFeatures(&sentence, &features); - - // Predict language. - EmbeddingNetwork::Vector scores; - network_->ComputeFinalScores(features, &scores); - - return ComputeSoftmax(scores); - } - - bool is_valid() const { return valid_; } - - int version() const { return version_; } - - private: - // Returns name of the (in-memory) file for the indicated TaskInput from - // context. 
- static std::string GetInMemoryFileNameForTaskInput( - const std::string &input_name, TaskContext *context) { - TaskInput *task_input = context->GetInput(input_name); - if (task_input->part_size() != 1) { - TC_LOG(ERROR) << "TaskInput " << input_name << " has " - << task_input->part_size() << " parts"; - return ""; - } - return task_input->part(0).file_pattern(); - } - - bool ParseNetworkParams(const InMemoryModelData &model_data, - TaskContext *context) { - const std::string input_name = "language-identifier-network"; - const std::string input_file_name = - GetInMemoryFileNameForTaskInput(input_name, context); - if (input_file_name.empty()) { - TC_LOG(ERROR) << "No input file name for TaskInput " << input_name; - return false; - } - StringPiece bytes = model_data.GetBytesForInputFile(input_file_name); - if (bytes.data() == nullptr) { - TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name; - return false; - } - std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto()); - if (!ParseProtoFromMemory(bytes, proto.get())) { - TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto"; - return false; - } - network_params_.reset( - new EmbeddingNetworkParamsFromProto(std::move(proto))); - if (!network_params_->is_valid()) { - TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid"; - return false; - } - return true; - } - - // Parses dictionary with known languages (i.e., field languages_) from a - // TaskInput of context. Note: that TaskInput should be a ListOfStrings proto - // with a single element, the serialized form of a ListOfStrings. 
- // - bool ParseListOfKnownLanguages(const InMemoryModelData &model_data, - TaskContext *context) { - const std::string input_name = "language-name-id-map"; - const std::string input_file_name = - GetInMemoryFileNameForTaskInput(input_name, context); - if (input_file_name.empty()) { - TC_LOG(ERROR) << "No input file name for TaskInput " << input_name; - return false; - } - StringPiece bytes = model_data.GetBytesForInputFile(input_file_name); - if (bytes.data() == nullptr) { - TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name; - return false; - } - ListOfStrings records; - if (!ParseProtoFromMemory(bytes, &records)) { - TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput " - << input_name; - return false; - } - if (records.element_size() != 1) { - TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name - << " : " << records.element_size(); - return false; - } - if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) { - TC_LOG(ERROR) << "Unable to parse dictionary with known languages"; - return false; - } - return true; - } - - // Returns language code for a softmax label. See comments for languages_ - // field. If label is out of range, returns default_language_. - std::string GetLanguageForSoftmaxLabel(int label) const { - if ((label >= 0) && (label < languages_.element_size())) { - return languages_.element(label); - } else { - TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, " - << languages_.element_size() << ")"; - return default_language_; - } - } - - LangIdBrainInterface lang_id_brain_interface_; - - // Parameters for the neural network network_ (see below). - std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_; - - // Neural network to use for scoring. - std::unique_ptr<EmbeddingNetwork> network_; - - // True if this object is ready to perform language predictions. 
- bool valid_; - - // Only predictions with a probability (confidence) above this threshold are - // reported. Otherwise, we report default_language_. - float probability_threshold_ = kDefaultProbabilityThreshold; - - // Min size of the input text for our predictions to be meaningful. Below - // this threshold, the underlying model may report a wrong language and a high - // confidence score. - int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes; - - // Version of the model. - int version_ = -1; - - // Known languages: softmax label i (an integer) means languages_.element(i) - // (something like "en", "fr", "ru", etc). - ListOfStrings languages_; - - // Language code to return in case of errors. - std::string default_language_ = kInitialDefaultLanguage; - - TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl); -}; - -LangId::LangId(const std::string &filename) : pimpl_(new LangIdImpl(filename)) { - if (!pimpl_->is_valid()) { - TC_LOG(ERROR) << "Unable to construct a valid LangId based " - << "on the data from " << filename - << "; nothing should crash, but " - << "accuracy will be bad."; - } -} - -LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) { - if (!pimpl_->is_valid()) { - TC_LOG(ERROR) << "Unable to construct a valid LangId based " - << "on the data from descriptor " << fd - << "; nothing should crash, " - << "but accuracy will be bad."; - } -} - -LangId::LangId(const char *ptr, size_t length) - : pimpl_(new LangIdImpl(ptr, length)) { - if (!pimpl_->is_valid()) { - TC_LOG(ERROR) << "Unable to construct a valid LangId based " - << "on the memory region; nothing should crash, " - << "but accuracy will be bad."; - } -} - -LangId::~LangId() = default; - -void LangId::SetProbabilityThreshold(float threshold) { - pimpl_->SetProbabilityThreshold(threshold); -} - -void LangId::SetDefaultLanguage(const std::string &lang) { - pimpl_->SetDefaultLanguage(lang); -} - -std::string LangId::FindLanguage(const std::string &text) const { - return pimpl_->FindLanguage(text); -} - 
-std::vector<std::pair<std::string, float>> LangId::FindLanguages( - const std::string &text) const { - return pimpl_->FindLanguages(text); -} - -bool LangId::is_valid() const { return pimpl_->is_valid(); } - -int LangId::version() const { return pimpl_->version(); } - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h deleted file mode 100644 index 7653dde..0000000 --- a/lang_id/lang-id.h +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_ - -// Clients who want to perform language identification should use this header. -// -// Note for lang id implementors: keep this header as linght as possible. E.g., -// any macro defined here (or in a transitively #included file) is a potential -// name conflict with our clients. - -#include <memory> -#include <string> -#include <vector> - -#include "util/base/macros.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Forward-declaration of the class that performs all underlying work. -class LangIdImpl; - -// Class for detecting the language of a document. -// -// NOTE: this class is thread-unsafe. 
-class LangId { - public: - // Constructs a LangId object, loading an EmbeddingNetworkProto model from the - // indicated file. - // - // Note: we don't crash if we detect a problem at construction time (e.g., - // file doesn't exist, or its content is corrupted). Instead, we mark the - // newly-constructed object as invalid; clients can invoke FindLanguage() on - // an invalid object: nothing crashes, but accuracy will be bad. - explicit LangId(const std::string &filename); - - // Same as above but uses a file descriptor. - explicit LangId(int fd); - - // Same as above but uses already mapped memory region - explicit LangId(const char *ptr, size_t length); - - virtual ~LangId(); - - // Sets probability threshold for predictions. If our likeliest prediction is - // below this threshold, we report the default language (see - // SetDefaultLanguage()). Othewise, we report the likelist language. - // - // By default (if this method is not called) we use the probability threshold - // stored in the model, as the task parameter "reliability_thresh". If that - // task parameter is not specified, we use 0.5. A client can use this method - // to get a different precision / recall trade-off. The higher the threshold, - // the higher the precision and lower the recall rate. - void SetProbabilityThreshold(float threshold); - - // Sets default language to report if errors prevent running the real - // inference code or if prediction confidence is too small. - void SetDefaultLanguage(const std::string &lang); - - // Returns language code for the most likely language that text is written in. - // Note: if this LangId object is not valid (see - // is_valid()), this method returns the default language specified via - // SetDefaultLanguage() or (if that method was never invoked), the empty - // std::string. - std::string FindLanguage(const std::string &text) const; - - // Returns a vector of language codes along with the probability for each - // language. 
The result contains at least one element. The sum of - // probabilities may be less than 1.0. - std::vector<std::pair<std::string, float>> FindLanguages( - const std::string &text) const; - - // Returns true if this object has been correctly initialized and is ready to - // perform predictions. For more info, see doc for LangId - // constructor above. - bool is_valid() const; - - // Returns version number for the model. - int version() const; - - private: - // Returns a vector of probabilities of languages of the text. - std::vector<float> ScoreLanguages(const std::string &text) const; - - // Pimpl ("pointer to implementation") pattern, to hide all internals from our - // clients. - std::unique_ptr<LangIdImpl> pimpl_; - - TC_DISALLOW_COPY_AND_ASSIGN(LangId); -}; - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_ diff --git a/lang_id/lang-id_test.cc b/lang_id/lang-id_test.cc deleted file mode 100644 index 2f8aedd..0000000 --- a/lang_id/lang-id_test.cc +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "lang_id/lang-id.h" - -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "util/base/logging.h" -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -namespace { - -std::string GetModelPath() { - return TEST_DATA_DIR "langid.model"; -} - -// Creates a LangId with default model. Passes ownership to -// the caller. -LangId *CreateLanguageDetector() { return new LangId(GetModelPath()); } - -} // namespace - -TEST(LangIdTest, Normal) { - std::unique_ptr<LangId> lang_id(CreateLanguageDetector()); - - EXPECT_EQ("en", lang_id->FindLanguage("This text is written in English.")); - EXPECT_EQ("en", - lang_id->FindLanguage("This text is written in English. ")); - EXPECT_EQ("en", - lang_id->FindLanguage(" This text is written in English. ")); - EXPECT_EQ("fr", lang_id->FindLanguage("Vive la France! Vive la France!")); - EXPECT_EQ("ro", lang_id->FindLanguage("Sunt foarte foarte foarte fericit!")); -} - -// Test that for very small queries, we return the default language and a low -// confidence score. -TEST(LangIdTest, SuperSmallQueries) { - std::unique_ptr<LangId> lang_id(CreateLanguageDetector()); - - // Use a default language different from any real language: to be sure the - // result is the default language, not a language that happens to be the - // default language. - const std::string kDefaultLanguage = "dflt-lng"; - lang_id->SetDefaultLanguage(kDefaultLanguage); - - // Test the simple FindLanguage() method: that method returns a single - // language. 
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("y")); - EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("j")); - EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("l")); - EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("w")); - EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("z")); - EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("zulu")); - - // Test the more complex FindLanguages() method: that method returns a vector - // of (language, confidence_score) pairs. - std::vector<std::pair<std::string, float>> languages; - languages = lang_id->FindLanguages("y"); - EXPECT_EQ(1, languages.size()); - EXPECT_EQ(kDefaultLanguage, languages[0].first); - EXPECT_GT(0.01f, languages[0].second); - - languages = lang_id->FindLanguages("Todoist"); - EXPECT_EQ(1, languages.size()); - EXPECT_EQ(kDefaultLanguage, languages[0].first); - EXPECT_GT(0.01f, languages[0].second); - - // A few tests with a default language that is a real language code. - const std::string kJapanese = "ja"; - lang_id->SetDefaultLanguage(kJapanese); - EXPECT_EQ(kJapanese, lang_id->FindLanguage("y")); - EXPECT_EQ(kJapanese, lang_id->FindLanguage("j")); - EXPECT_EQ(kJapanese, lang_id->FindLanguage("l")); - languages = lang_id->FindLanguages("y"); - EXPECT_EQ(1, languages.size()); - EXPECT_EQ(kJapanese, languages[0].first); - EXPECT_GT(0.01f, languages[0].second); - - // Make sure the min text size limit is applied to the number of real - // characters (e.g., without spaces and punctuation chars, which don't - // influence language identification). - const std::string kWhitespaces = " \t \n \t\t\t\n \t"; - const std::string kPunctuation = "... 
?!!--- -%%^...-"; - std::string still_small_string = kWhitespaces + "y" + kWhitespaces + - kPunctuation + kWhitespaces + kPunctuation + - kPunctuation; - EXPECT_LE(100, still_small_string.size()); - lang_id->SetDefaultLanguage(kDefaultLanguage); - EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage(still_small_string)); - languages = lang_id->FindLanguages(still_small_string); - EXPECT_EQ(1, languages.size()); - EXPECT_EQ(kDefaultLanguage, languages[0].first); - EXPECT_GT(0.01f, languages[0].second); -} - -namespace { -void CheckPredictionForGibberishStrings(const std::string &default_language) { - static const char *const kGibberish[] = { - "", - " ", - " ", - " ___ ", - "123 456 789", - "><> (-_-) <><", - nullptr, - }; - - std::unique_ptr<LangId> lang_id(CreateLanguageDetector()); - TC_LOG(INFO) << "Default language: " << default_language; - lang_id->SetDefaultLanguage(default_language); - for (int i = 0; true; ++i) { - const char *gibberish = kGibberish[i]; - if (gibberish == nullptr) { - break; - } - const std::string predicted_language = lang_id->FindLanguage(gibberish); - TC_LOG(INFO) << "Predicted " << predicted_language << " for \"" << gibberish - << "\""; - EXPECT_EQ(default_language, predicted_language); - } -} -} // namespace - -TEST(LangIdTest, CornerCases) { - CheckPredictionForGibberishStrings("en"); - CheckPredictionForGibberishStrings("ro"); - CheckPredictionForGibberishStrings("fr"); -} - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/lang_id/language-identifier-features.cc b/lang_id/language-identifier-features.cc deleted file mode 100644 index 2e3912e..0000000 --- a/lang_id/language-identifier-features.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "lang_id/language-identifier-features.h" - -#include <utility> -#include <vector> - -#include "common/feature-extractor.h" -#include "common/feature-types.h" -#include "common/task-context.h" -#include "util/hash/hash.h" -#include "util/strings/utf8.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) { - // Parameters in the feature function descriptor. - ngram_id_dimension_ = GetIntParameter("id_dim", 10000); - ngram_size_ = GetIntParameter("size", 3); - - counts_.assign(ngram_id_dimension_, 0); - return true; -} - -bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) { - set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_)); - return true; -} - -int ContinuousBagOfNgramsFunction::ComputeNgramCounts( - const LightSentence &sentence) const { - // Invariant 1: counts_.size() == ngram_id_dimension_. Holds at the end of - // the constructor. After that, no method changes counts_.size(). - TC_DCHECK_EQ(counts_.size(), ngram_id_dimension_); - - // Invariant 2: the vector non_zero_count_indices_ is empty. The vector - // non_zero_count_indices_ is empty at construction time and gets emptied at - // the end of each call to Evaluate(). Hence, this invariant holds at the - // beginning of each run of Evaluate(), where the only call to this code takes - // place. 
- TC_DCHECK(non_zero_count_indices_.empty()); - - int total_count = 0; - - for (int i = 0; i < sentence.num_words(); ++i) { - const std::string &word = sentence.word(i); - const char *const word_end = word.data() + word.size(); - - // Set ngram_start at the start of the current token (word). - const char *ngram_start = word.data(); - - // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each - // UTF8 character contains between 1 and 4 bytes. - const char *ngram_end = ngram_start; - int num_utf8_chars = 0; - do { - ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end); - num_utf8_chars++; - } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end)); - - if (num_utf8_chars < ngram_size_) { - // Current token is so small, it does not contain a single ngram of - // ngram_size UTF8 characters. Not much we can do in this case ... - continue; - } - - // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size - // UTF8 characters from current token. - while (true) { - // Compute ngram_id: hash(ngram) % ngram_id_dimension - int ngram_id = - (Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start) % - ngram_id_dimension_); - - // Use a reference to the actual count, such that we can both test whether - // the count was 0 and increment it without perfoming two lookups. - // - // Due to the way we compute ngram_id, 0 <= ngram_id < ngram_id_dimension. - // Hence, by Invariant 1 (above), the access counts_[ngram_id] is safe. - int &ref_to_count_for_ngram = counts_[ngram_id]; - if (ref_to_count_for_ngram == 0) { - non_zero_count_indices_.push_back(ngram_id); - } - ref_to_count_for_ngram++; - total_count++; - if (ngram_end >= word_end) { - break; - } - - // Advance both ngram_start and ngram_end by one UTF8 character. This - // way, the number of UTF8 characters between them remains constant - // (ngram_size). 
- ngram_start += GetNumBytesForNonZeroUTF8Char(ngram_start); - ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end); - } - } // end of loop over tokens. - - return total_count; -} - -void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces, - const LightSentence &sentence, - FeatureVector *result) const { - // Find the char ngram counts. - int total_count = ComputeNgramCounts(sentence); - - // Populate the feature vector. - const float norm = static_cast<float>(total_count); - - for (int ngram_id : non_zero_count_indices_) { - const float weight = counts_[ngram_id] / norm; - FloatFeatureValue value(ngram_id, weight); - result->add(feature_type(), value.discrete_value); - - // Clear up counts_, for the next invocation of Evaluate(). - counts_[ngram_id] = 0; - } - - // Clear up non_zero_count_indices_, for the next invocation of Evaluate(). - non_zero_count_indices_.clear(); -} - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/lang_id/language-identifier-features.h b/lang_id/language-identifier-features.h deleted file mode 100644 index a4e3b3d..0000000 --- a/lang_id/language-identifier-features.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANGUAGE_IDENTIFIER_FEATURES_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_LANGUAGE_IDENTIFIER_FEATURES_H_ - -#include <string> - -#include "common/feature-extractor.h" -#include "common/task-context.h" -#include "common/workspace.h" -#include "lang_id/light-sentence-features.h" -#include "lang_id/light-sentence.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Class for computing continuous char ngram features. -// -// Feature function descriptor parameters: -// id_dim(int, 10000): -// The integer id of each char ngram is computed as follows: -// Hash32WithDefaultSeed(char ngram) % id_dim. -// size(int, 3): -// Only ngrams of this size will be extracted. -// -// NOTE: this class is not thread-safe. TODO(salcianu): make it thread-safe. -class ContinuousBagOfNgramsFunction : public LightSentenceFeature { - public: - bool Setup(TaskContext *context) override; - bool Init(TaskContext *context) override; - - // Appends the features computed from the sentence to the feature vector. - void Evaluate(const WorkspaceSet &workspaces, const LightSentence &sentence, - FeatureVector *result) const override; - - TC_DEFINE_REGISTRATION_METHOD("continuous-bag-of-ngrams", - ContinuousBagOfNgramsFunction); - - private: - // Auxiliary for Evaluate(). Fills counts_ and non_zero_count_indices_ (see - // below), and returns the total ngram count. - int ComputeNgramCounts(const LightSentence &sentence) const; - - // counts_[i] is the count of all ngrams with id i. Work data for Evaluate(). - // NOTE: we declare this vector as a field, such that its underlying capacity - // stays allocated in between calls to Evaluate(). - mutable std::vector<int> counts_; - - // Indices of non-zero elements of counts_. See comments for counts_. - mutable std::vector<int> non_zero_count_indices_; - - // The integer id of each char ngram is computed as follows: - // Hash32WithDefaultSeed(char_ngram) % ngram_id_dimension_. 
- int ngram_id_dimension_; - - // Only ngrams of size ngram_size_ will be extracted. - int ngram_size_; -}; - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_LANGUAGE_IDENTIFIER_FEATURES_H_ diff --git a/lang_id/light-sentence-features.h b/lang_id/light-sentence-features.h deleted file mode 100644 index a140f65..0000000 --- a/lang_id/light-sentence-features.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_FEATURES_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_FEATURES_H_ - -#include "common/feature-extractor.h" -#include "lang_id/light-sentence.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Feature function that extracts features from LightSentences. -typedef FeatureFunction<LightSentence> LightSentenceFeature; - -// Feature extractor for LightSentences. -typedef FeatureExtractor<LightSentence> LightSentenceExtractor; - -} // namespace lang_id - -// Should be used in namespace libtextclassifier::nlp_core. 
-TC_DECLARE_CLASS_REGISTRY_NAME(lang_id::LightSentenceFeature); - -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_FEATURES_H_ diff --git a/lang_id/light-sentence.h b/lang_id/light-sentence.h deleted file mode 100644 index e8451be..0000000 --- a/lang_id/light-sentence.h +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_H_ - -#include <string> -#include <vector> - -#include "util/base/logging.h" -#include "util/base/macros.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Simplified replacement for the Sentence proto, for internal use in the -// language identification code. -// -// In this simplified form, a sentence is a vector of words, each word being a -// string. -class LightSentence { - public: - LightSentence() {} - - // Adds a new word after all existing ones, and returns a pointer to it. The - // new word is initialized to the empty string. - std::string *add_word() { - words_.emplace_back(); - return &(words_.back()); - } - - // Returns number of words from this LightSentence. - int num_words() const { return words_.size(); } - - // Returns the ith word from this LightSentence. Note: undefined behavior if - // i is out of bounds. 
- const std::string &word(int i) const { - TC_DCHECK((i >= 0) && (i < num_words())); - return words_[i]; - } - - private: - std::vector<std::string> words_; - - TC_DISALLOW_COPY_AND_ASSIGN(LightSentence); -}; - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_H_ diff --git a/lang_id/relevant-script-feature.cc b/lang_id/relevant-script-feature.cc deleted file mode 100644 index c865ce5..0000000 --- a/lang_id/relevant-script-feature.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "lang_id/relevant-script-feature.h" - -#include <string> - -#include "common/feature-extractor.h" -#include "common/feature-types.h" -#include "common/task-context.h" -#include "common/workspace.h" -#include "lang_id/script-detector.h" -#include "util/base/logging.h" -#include "util/strings/utf8.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -bool RelevantScriptFeature::Setup(TaskContext *context) { return true; } - -bool RelevantScriptFeature::Init(TaskContext *context) { - set_feature_type(new NumericFeatureType(name(), kNumRelevantScripts)); - return true; -} - -void RelevantScriptFeature::Evaluate(const WorkspaceSet &workspaces, - const LightSentence &sentence, - FeatureVector *result) const { - // We expect kNumRelevantScripts to be small, so we stack-allocate the array - // of counts. Still, if that changes, we want to find out. - static_assert( - kNumRelevantScripts < 25, - "switch counts to vector<int>: too big for stack-allocated int[]"); - - // counts[s] is the number of characters with script s. - // Note: {} "value-initializes" the array to zero. - int counts[kNumRelevantScripts]{}; - int total_count = 0; - for (int i = 0; i < sentence.num_words(); ++i) { - const std::string &word = sentence.word(i); - const char *const word_end = word.data() + word.size(); - const char *curr = word.data(); - - // Skip over token start '^'. - TC_DCHECK_EQ(*curr, '^'); - curr += GetNumBytesForNonZeroUTF8Char(curr); - while (true) { - const int num_bytes = GetNumBytesForNonZeroUTF8Char(curr); - Script script = GetScript(curr, num_bytes); - - // We do this update and the if (...) break below *before* incrementing - // counts[script] in order to skip the token end '$'. 
- curr += num_bytes; - if (curr >= word_end) { - TC_DCHECK_EQ(*(curr - num_bytes), '$'); - break; - } - TC_DCHECK_GE(script, 0); - TC_DCHECK_LT(script, kNumRelevantScripts); - counts[script]++; - total_count++; - } - } - - for (int script_id = 0; script_id < kNumRelevantScripts; ++script_id) { - int count = counts[script_id]; - if (count > 0) { - const float weight = static_cast<float>(count) / total_count; - FloatFeatureValue value(script_id, weight); - result->add(feature_type(), value.discrete_value); - } - } -} - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier diff --git a/lang_id/relevant-script-feature.h b/lang_id/relevant-script-feature.h deleted file mode 100644 index 2aa2420..0000000 --- a/lang_id/relevant-script-feature.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_RELEVANT_SCRIPT_FEATURE_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_RELEVANT_SCRIPT_FEATURE_H_ - -#include "common/feature-extractor.h" -#include "common/task-context.h" -#include "common/workspace.h" -#include "lang_id/light-sentence-features.h" -#include "lang_id/light-sentence.h" - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Given a sentence, generates one FloatFeatureValue for each "relevant" Unicode -// script (see below): each such feature indicates the script and the ratio of -// UTF8 characters in that script, in the given sentence. -// -// What is a relevant script? Recognizing all 100+ Unicode scripts would -// require too much code size and runtime. Instead, we focus only on a few -// scripts that communicate a lot of language information: e.g., the use of -// Hiragana characters almost always indicates Japanese, so Hiragana is a -// "relevant" script for us. The Latin script is used by dozens of language, so -// Latin is not relevant in this context. -class RelevantScriptFeature : public LightSentenceFeature { - public: - // Idiomatic SAFT Setup() and Init(). - bool Setup(TaskContext *context) override; - bool Init(TaskContext *context) override; - - // Appends the features computed from the sentence to the feature vector. 
- void Evaluate(const WorkspaceSet &workspaces, const LightSentence &sentence, - FeatureVector *result) const override; - - TC_DEFINE_REGISTRATION_METHOD("continuous-bag-of-relevant-scripts", - RelevantScriptFeature); -}; - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_RELEVANT_SCRIPT_FEATURE_H_ diff --git a/lang_id/script-detector.h b/lang_id/script-detector.h deleted file mode 100644 index cf816ee..0000000 --- a/lang_id/script-detector.h +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_LANG_ID_SCRIPT_DETECTOR_H_ -#define LIBTEXTCLASSIFIER_LANG_ID_SCRIPT_DETECTOR_H_ - -namespace libtextclassifier { -namespace nlp_core { -namespace lang_id { - -// Unicode scripts we care about. To get compact and fast code, we detect only -// a few Unicode scripts that offer a strong indication about the language of -// the text (e.g., Hiragana -> Japanese). -enum Script { - // Special value to indicate internal errors in the script detection code. - kScriptError, - - // Special values for all Unicode scripts that we do not detect. One special - // value for Unicode characters of 1, 2, 3, respectively 4 bytes (as we - // already have that information, we use it). kScriptOtherUtf8OneByte means - // ~Latin and kScriptOtherUtf8FourBytes means ~Han. 
- kScriptOtherUtf8OneByte, - kScriptOtherUtf8TwoBytes, - kScriptOtherUtf8ThreeBytes, - kScriptOtherUtf8FourBytes, - - kScriptGreek, - kScriptCyrillic, - kScriptHebrew, - kScriptArabic, - kScriptHangulJamo, // Used primarily for Korean. - kScriptHiragana, // Used primarily for Japanese. - kScriptKatakana, // Used primarily for Japanese. - - // Add new scripts here. - - // Do not add any script after kNumRelevantScripts. This value indicates the - // number of elements in this enum Script (except this value) such that we can - // easily iterate over the scripts. - kNumRelevantScripts, -}; - -template<typename IntType> -inline bool InRange(IntType value, IntType low, IntType hi) { - return (value >= low) && (value <= hi); -} - -// Returns Script for the UTF8 character that starts at address p. -// Precondition: p points to a valid UTF8 character of num_bytes bytes. -inline Script GetScript(const unsigned char *p, int num_bytes) { - switch (num_bytes) { - case 1: - return kScriptOtherUtf8OneByte; - - case 2: { - // 2-byte UTF8 characters have 11 bits of information. unsigned int has - // at least 16 bits (http://en.cppreference.com/w/cpp/language/types) so - // it's enough. It's also usually the fastest int type on the current - // CPU, so it's better to use than int32. 
- static const unsigned int kGreekStart = 0x370; - - // Commented out (unsued in the code): kGreekEnd = 0x3FF; - static const unsigned int kCyrillicStart = 0x400; - static const unsigned int kCyrillicEnd = 0x4FF; - static const unsigned int kHebrewStart = 0x590; - - // Commented out (unsued in the code): kHebrewEnd = 0x5FF; - static const unsigned int kArabicStart = 0x600; - static const unsigned int kArabicEnd = 0x6FF; - const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F); - if (codepoint > kCyrillicEnd) { - if (codepoint >= kArabicStart) { - if (codepoint <= kArabicEnd) { - return kScriptArabic; - } - } else { - // At this point, codepoint < kArabicStart = kHebrewEnd + 1, so - // codepoint <= kHebrewEnd. - if (codepoint >= kHebrewStart) { - return kScriptHebrew; - } - } - } else { - if (codepoint >= kCyrillicStart) { - return kScriptCyrillic; - } else { - // At this point, codepoint < kCyrillicStart = kGreekEnd + 1, so - // codepoint <= kGreekEnd. - if (codepoint >= kGreekStart) { - return kScriptGreek; - } - } - } - return kScriptOtherUtf8TwoBytes; - } - - case 3: { - // 3-byte UTF8 characters have 16 bits of information. unsigned int has - // at least 16 bits. - static const unsigned int kHangulJamoStart = 0x1100; - static const unsigned int kHangulJamoEnd = 0x11FF; - static const unsigned int kHiraganaStart = 0x3041; - static const unsigned int kHiraganaEnd = 0x309F; - - // Commented out (unsued in the code): kKatakanaStart = 0x30A0; - static const unsigned int kKatakanaEnd = 0x30FF; - const unsigned int codepoint = - ((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F); - if (codepoint > kHiraganaEnd) { - // On this branch, codepoint > kHiraganaEnd = kKatakanaStart - 1, so - // codepoint >= kKatakanaStart. 
- if (codepoint <= kKatakanaEnd) { - return kScriptKatakana; - } - } else { - if (codepoint >= kHiraganaStart) { - return kScriptHiragana; - } else { - if (InRange(codepoint, kHangulJamoStart, kHangulJamoEnd)) { - return kScriptHangulJamo; - } - } - } - return kScriptOtherUtf8ThreeBytes; - } - - case 4: - return kScriptOtherUtf8FourBytes; - - default: - return kScriptError; - } -} - -// Returns Script for the UTF8 character that starts at address p. Similar to -// the previous version of GetScript, except for "char" vs "unsigned char". -// Most code works with "char *" pointers, ignoring the fact that char is -// unsigned (by default) on most platforms, but signed on iOS. This code takes -// care of making sure we always treat chars as unsigned. -inline Script GetScript(const char *p, int num_bytes) { - return GetScript(reinterpret_cast<const unsigned char *>(p), - num_bytes); -} - -} // namespace lang_id -} // namespace nlp_core -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_LANG_ID_SCRIPT_DETECTOR_H_ diff --git a/model-executor.cc b/model-executor.cc new file mode 100644 index 0000000..69931cb --- /dev/null +++ b/model-executor.cc @@ -0,0 +1,162 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "model-executor.h" + +#include "quantization.h" +#include "util/base/logging.h" + +namespace libtextclassifier2 { +namespace internal { +bool FromModelSpec(const tflite::Model* model_spec, + std::unique_ptr<const tflite::FlatBufferModel>* model) { + *model = tflite::FlatBufferModel::BuildFromModel(model_spec); + if (!(*model) || !(*model)->initialized()) { + TC_LOG(ERROR) << "Could not build TFLite model from a model spec. "; + return false; + } + return true; +} +} // namespace internal + +std::unique_ptr<tflite::Interpreter> ModelExecutor::CreateInterpreter() const { + std::unique_ptr<tflite::Interpreter> interpreter; + tflite::InterpreterBuilder(*model_, builtins_)(&interpreter); + return interpreter; +} + +std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::Instance( + const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, + int quantization_bits) { + const tflite::Model* model_spec = + flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); + flatbuffers::Verifier verifier(model_spec_buffer->data(), + model_spec_buffer->Length()); + std::unique_ptr<const tflite::FlatBufferModel> model; + if (!model_spec->Verify(verifier) || + !internal::FromModelSpec(model_spec, &model)) { + TC_LOG(ERROR) << "Could not load TFLite model."; + return nullptr; + } + + std::unique_ptr<tflite::Interpreter> interpreter; + tflite::ops::builtin::BuiltinOpResolver builtins; + tflite::InterpreterBuilder(*model, builtins)(&interpreter); + if (!interpreter) { + TC_LOG(ERROR) << "Could not build TFLite interpreter for embeddings."; + return nullptr; + } + + if (interpreter->tensors_size() != 2) { + return nullptr; + } + const TfLiteTensor* embeddings = interpreter->tensor(0); + if (embeddings->dims->size != 2) { + return nullptr; + } + int num_buckets = embeddings->dims->data[0]; + const TfLiteTensor* scales = interpreter->tensor(1); + if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets || + scales->dims->data[1] != 
1) { + return nullptr; + } + int bytes_per_embedding = embeddings->dims->data[1]; + if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits, + embedding_size)) { + TC_LOG(ERROR) << "Mismatch in quantization parameters."; + return nullptr; + } + + return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor( + std::move(model), quantization_bits, num_buckets, bytes_per_embedding, + embedding_size, scales, embeddings, std::move(interpreter))); +} + +TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor( + std::unique_ptr<const tflite::FlatBufferModel> model, int quantization_bits, + int num_buckets, int bytes_per_embedding, int output_embedding_size, + const TfLiteTensor* scales, const TfLiteTensor* embeddings, + std::unique_ptr<tflite::Interpreter> interpreter) + : model_(std::move(model)), + quantization_bits_(quantization_bits), + num_buckets_(num_buckets), + bytes_per_embedding_(bytes_per_embedding), + output_embedding_size_(output_embedding_size), + scales_(scales), + embeddings_(embeddings), + interpreter_(std::move(interpreter)) {} + +bool TFLiteEmbeddingExecutor::AddEmbedding( + const TensorView<int>& sparse_features, float* dest, int dest_size) const { + if (dest_size != output_embedding_size_) { + TC_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: " + << dest_size << " " << output_embedding_size_; + return false; + } + const int num_sparse_features = sparse_features.size(); + for (int i = 0; i < num_sparse_features; ++i) { + const int bucket_id = sparse_features.data()[i]; + if (bucket_id >= num_buckets_) { + return false; + } + + if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8, + bytes_per_embedding_, num_sparse_features, + quantization_bits_, bucket_id, dest, dest_size)) { + return false; + } + } + return true; +} + +TensorView<float> ComputeLogitsHelper(const int input_index_features, + const int output_index_logits, + const TensorView<float>& features, + tflite::Interpreter* interpreter) { + if 
(!interpreter) { + return TensorView<float>::Invalid(); + } + interpreter->ResizeInputTensor(input_index_features, features.shape()); + if (interpreter->AllocateTensors() != kTfLiteOk) { + TC_VLOG(1) << "Allocation failed."; + return TensorView<float>::Invalid(); + } + + TfLiteTensor* features_tensor = + interpreter->tensor(interpreter->inputs()[input_index_features]); + int size = 1; + for (int i = 0; i < features_tensor->dims->size; ++i) { + size *= features_tensor->dims->data[i]; + } + features.copy_to(features_tensor->data.f, size); + + if (interpreter->Invoke() != kTfLiteOk) { + TC_VLOG(1) << "Interpreter failed."; + return TensorView<float>::Invalid(); + } + + TfLiteTensor* logits_tensor = + interpreter->tensor(interpreter->outputs()[output_index_logits]); + + std::vector<int> output_shape(logits_tensor->dims->size); + for (int i = 0; i < logits_tensor->dims->size; ++i) { + output_shape[i] = logits_tensor->dims->data[i]; + } + + return TensorView<float>(logits_tensor->data.f, output_shape); +} + +} // namespace libtextclassifier2 diff --git a/model-executor.h b/model-executor.h new file mode 100644 index 0000000..ef6d36f --- /dev/null +++ b/model-executor.h @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Contains classes that can execute different models/parts of a model. 
+ +#ifndef LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ +#define LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ + +#include <memory> + +#include "tensor-view.h" +#include "types.h" +#include "util/base/logging.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +namespace libtextclassifier2 { + +namespace internal { +bool FromModelSpec(const tflite::Model* model_spec, + std::unique_ptr<const tflite::FlatBufferModel>* model); +} // namespace internal + +// A helper function that given indices of feature and logits tensor, feature +// values computes the logits using given interpreter. +TensorView<float> ComputeLogitsHelper(const int input_index_features, + const int output_index_logits, + const TensorView<float>& features, + tflite::Interpreter* interpreter); + +// Executor for the text selection prediction and classification models. +class ModelExecutor { + public: + static std::unique_ptr<const ModelExecutor> Instance( + const flatbuffers::Vector<uint8_t>* model_spec_buffer) { + const tflite::Model* model = + flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); + flatbuffers::Verifier verifier(model_spec_buffer->data(), + model_spec_buffer->Length()); + if (!model->Verify(verifier)) { + return nullptr; + } + return Instance(model); + } + + static std::unique_ptr<const ModelExecutor> Instance( + const tflite::Model* model_spec) { + std::unique_ptr<const tflite::FlatBufferModel> model; + if (!internal::FromModelSpec(model_spec, &model)) { + return nullptr; + } + return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); + } + + // Creates an Interpreter for the model that serves as a scratch-pad for the + // inference. The Interpreter is NOT thread-safe. 
+ std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; + + TensorView<float> ComputeLogits(const TensorView<float>& features, + tflite::Interpreter* interpreter) const { + return ComputeLogitsHelper(kInputIndexFeatures, kOutputIndexLogits, + features, interpreter); + } + + protected: + explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model) + : model_(std::move(model)) {} + + static const int kInputIndexFeatures = 0; + static const int kOutputIndexLogits = 0; + + std::unique_ptr<const tflite::FlatBufferModel> model_; + tflite::ops::builtin::BuiltinOpResolver builtins_; +}; + +// Executor for embedding sparse features into a dense vector. +class EmbeddingExecutor { + public: + virtual ~EmbeddingExecutor() {} + + // Embeds the sparse_features into a dense embedding and adds (+) it + // element-wise to the dest vector. + virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, + int dest_size) const = 0; + + // Returns true when the model is ready to be used, false otherwise. 
+ virtual bool IsReady() const { return true; } +}; + +class TFLiteEmbeddingExecutor : public EmbeddingExecutor { + public: + static std::unique_ptr<TFLiteEmbeddingExecutor> Instance( + const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, + int quantization_bits); + + bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, + int dest_size) const override; + + protected: + explicit TFLiteEmbeddingExecutor( + std::unique_ptr<const tflite::FlatBufferModel> model, + int quantization_bits, int num_buckets, int bytes_per_embedding, + int output_embedding_size, const TfLiteTensor* scales, + const TfLiteTensor* embeddings, + std::unique_ptr<tflite::Interpreter> interpreter); + + std::unique_ptr<const tflite::FlatBufferModel> model_; + + int quantization_bits_; + int num_buckets_ = -1; + int bytes_per_embedding_ = -1; + int output_embedding_size_ = -1; + const TfLiteTensor* scales_ = nullptr; + const TfLiteTensor* embeddings_ = nullptr; + + // NOTE: This interpreter is used in a read-only way (as a storage for the + // model params), thus is still thread-safe. + std::unique_ptr<tflite::Interpreter> interpreter_; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ diff --git a/model.fbs b/model.fbs new file mode 100755 index 0000000..fb9778b --- /dev/null +++ b/model.fbs @@ -0,0 +1,577 @@ +// +// Copyright (C) 2017 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +file_identifier "TC2 "; + +// The possible model modes, represents a bit field. +namespace libtextclassifier2; +enum ModeFlag : int { + NONE = 0, + ANNOTATION = 1, + CLASSIFICATION = 2, + ANNOTATION_AND_CLASSIFICATION = 3, + SELECTION = 4, + ANNOTATION_AND_SELECTION = 5, + CLASSIFICATION_AND_SELECTION = 6, + ALL = 7, +} + +namespace libtextclassifier2; +enum DatetimeExtractorType : int { + UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0, + AM = 1, + PM = 2, + JANUARY = 3, + FEBRUARY = 4, + MARCH = 5, + APRIL = 6, + MAY = 7, + JUNE = 8, + JULY = 9, + AUGUST = 10, + SEPTEMBER = 11, + OCTOBER = 12, + NOVEMBER = 13, + DECEMBER = 14, + NEXT = 15, + NEXT_OR_SAME = 16, + LAST = 17, + NOW = 18, + TOMORROW = 19, + YESTERDAY = 20, + PAST = 21, + FUTURE = 22, + DAY = 23, + WEEK = 24, + MONTH = 25, + YEAR = 26, + MONDAY = 27, + TUESDAY = 28, + WEDNESDAY = 29, + THURSDAY = 30, + FRIDAY = 31, + SATURDAY = 32, + SUNDAY = 33, + DAYS = 34, + WEEKS = 35, + MONTHS = 36, + HOURS = 37, + MINUTES = 38, + SECONDS = 39, + YEARS = 40, + DIGITS = 41, + SIGNEDDIGITS = 42, + ZERO = 43, + ONE = 44, + TWO = 45, + THREE = 46, + FOUR = 47, + FIVE = 48, + SIX = 49, + SEVEN = 50, + EIGHT = 51, + NINE = 52, + TEN = 53, + ELEVEN = 54, + TWELVE = 55, + THIRTEEN = 56, + FOURTEEN = 57, + FIFTEEN = 58, + SIXTEEN = 59, + SEVENTEEN = 60, + EIGHTEEN = 61, + NINETEEN = 62, + TWENTY = 63, + THIRTY = 64, + FORTY = 65, + FIFTY = 66, + SIXTY = 67, + SEVENTY = 68, + EIGHTY = 69, + NINETY = 70, + HUNDRED = 71, + THOUSAND = 72, +} + +namespace libtextclassifier2; +enum DatetimeGroupType : int { + GROUP_UNKNOWN = 0, + GROUP_UNUSED = 1, + GROUP_YEAR = 2, + GROUP_MONTH = 3, + GROUP_DAY = 4, + GROUP_HOUR = 5, + GROUP_MINUTE = 6, + GROUP_SECOND = 7, + GROUP_AMPM = 8, + GROUP_RELATIONDISTANCE = 9, + GROUP_RELATION = 10, + GROUP_RELATIONTYPE = 11, + + // Dummy groups serve just as an inflator of the selection. E.g. we might want + // to select more text than was contained in an envelope of all extractor + // spans. 
+ GROUP_DUMMY1 = 12, + + GROUP_DUMMY2 = 13, +} + +namespace libtextclassifier2; +table CompressedBuffer { + buffer:[ubyte]; + uncompressed_size:int; +} + +// Options for the model that predicts text selection. +namespace libtextclassifier2; +table SelectionModelOptions { + // If true, before the selection is returned, the unpaired brackets contained + // in the predicted selection are stripped from the both selection ends. + // The bracket codepoints are defined in the Unicode standard: + // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt + strip_unpaired_brackets:bool = 1; + + // Number of hypothetical click positions on either side of the actual click + // to consider in order to enforce symmetry. + symmetry_context_size:int; + + // Number of examples to bundle in one batch for inference. + batch_size:int = 1024; + + // Whether to always classify a suggested selection or only on demand. + always_classify_suggested_selection:bool = 0; +} + +// Options for the model that classifies a text selection. +namespace libtextclassifier2; +table ClassificationModelOptions { + // Limits for phone numbers. + phone_min_num_digits:int = 7; + + phone_max_num_digits:int = 15; + + // Limits for addresses. + address_min_num_tokens:int; + + // Maximum number of tokens to attempt a classification (-1 is unlimited). + max_num_tokens:int = -1; +} + +// List of regular expression matchers to check. +namespace libtextclassifier2.RegexModel_; +table Pattern { + // The name of the collection of a match. + collection_name:string; + + // The pattern to check. + // Can specify a single capturing group used as match boundaries. + pattern:string; + + // The modes for which to apply the patterns. + enabled_modes:libtextclassifier2.ModeFlag = ALL; + + // The final score to assign to the results of this pattern. + target_classification_score:float = 1; + + // Priority score used for conflict resolution with the other models. 
+ priority_score:float = 0; + + // If true, will use an approximate matching implementation implemented + // using Find() instead of the true Match(). This approximate matching will + // use the first Find() result and then check that it spans the whole input. + use_approximate_matching:bool = 0; + + compressed_pattern:libtextclassifier2.CompressedBuffer; +} + +namespace libtextclassifier2; +table RegexModel { + patterns:[libtextclassifier2.RegexModel_.Pattern]; +} + +// List of regex patterns. +namespace libtextclassifier2.DatetimeModelPattern_; +table Regex { + pattern:string; + + // The ith entry specifies the type of the ith capturing group. + // This is used to decide how the matched content has to be parsed. + groups:[libtextclassifier2.DatetimeGroupType]; + + compressed_pattern:libtextclassifier2.CompressedBuffer; +} + +namespace libtextclassifier2; +table DatetimeModelPattern { + regexes:[libtextclassifier2.DatetimeModelPattern_.Regex]; + + // List of locale indices in DatetimeModel that represent the locales that + // these patterns should be used for. If empty, can be used for all locales. + locales:[int]; + + // The final score to assign to the results of this pattern. + target_classification_score:float = 1; + + // Priority score used for conflict resulution with the other models. + priority_score:float = 0; + + // The modes for which to apply the patterns. + enabled_modes:libtextclassifier2.ModeFlag = ALL; +} + +namespace libtextclassifier2; +table DatetimeModelExtractor { + extractor:libtextclassifier2.DatetimeExtractorType; + pattern:string; + locales:[int]; + compressed_pattern:libtextclassifier2.CompressedBuffer; +} + +namespace libtextclassifier2; +table DatetimeModel { + // List of BCP 47 locale strings representing all locales supported by the + // model. The individual patterns refer back to them using an index. 
+ locales:[string]; + + patterns:[libtextclassifier2.DatetimeModelPattern]; + extractors:[libtextclassifier2.DatetimeModelExtractor]; + + // If true, will use the extractors for determining the match location as + // opposed to using the location where the global pattern matched. + use_extractors_for_locating:bool = 1; + + // List of locale ids, rules of whose are always run, after the requested + // ones. + default_locales:[int]; +} + +namespace libtextclassifier2.DatetimeModelLibrary_; +table Item { + key:string; + value:libtextclassifier2.DatetimeModel; +} + +// A set of named DateTime models. +namespace libtextclassifier2; +table DatetimeModelLibrary { + models:[libtextclassifier2.DatetimeModelLibrary_.Item]; +} + +// Options controlling the output of the Tensorflow Lite models. +namespace libtextclassifier2; +table ModelTriggeringOptions { + // Lower bound threshold for filtering annotation model outputs. + min_annotate_confidence:float = 0; + + // The modes for which to enable the models. + enabled_modes:libtextclassifier2.ModeFlag = ALL; +} + +// Options controlling the output of the classifier. +namespace libtextclassifier2; +table OutputOptions { + // Lists of collection names that will be filtered out at the output: + // - For annotation, the spans of given collection are simply dropped. + // - For classification, the result is mapped to the class "other". + // - For selection, the spans of given class are returned as + // single-selection. + filtered_collections_annotation:[string]; + + filtered_collections_classification:[string]; + filtered_collections_selection:[string]; +} + +namespace libtextclassifier2; +table Model { + // Comma-separated list of locales supported by the model as BCP 47 tags. + locales:string; + + version:int; + + // A name for the model that can be used for e.g. logging. 
+ name:string; + + selection_feature_options:libtextclassifier2.FeatureProcessorOptions; + classification_feature_options:libtextclassifier2.FeatureProcessorOptions; + + // Tensorflow Lite models. + selection_model:[ubyte] (force_align: 16); + + classification_model:[ubyte] (force_align: 16); + embedding_model:[ubyte] (force_align: 16); + + // Options for the different models. + selection_options:libtextclassifier2.SelectionModelOptions; + + classification_options:libtextclassifier2.ClassificationModelOptions; + regex_model:libtextclassifier2.RegexModel; + datetime_model:libtextclassifier2.DatetimeModel; + + // Options controlling the output of the models. + triggering_options:libtextclassifier2.ModelTriggeringOptions; + + // Global switch that controls if SuggestSelection(), ClassifyText() and + // Annotate() will run. If a mode is disabled it returns empty/no-op results. + enabled_modes:libtextclassifier2.ModeFlag = ALL; + + // If true, will snap the selections that consist only of whitespaces to the + // containing suggested span. Otherwise, no suggestion is proposed, since the + // selections are not part of any token. + snap_whitespace_selections:bool = 1; + + // Global configuration for the output of SuggestSelection(), ClassifyText() + // and Annotate(). + output_options:libtextclassifier2.OutputOptions; +} + +// Role of the codepoints in the range. +namespace libtextclassifier2.TokenizationCodepointRange_; +enum Role : int { + // Concatenates the codepoint to the current run of codepoints. + DEFAULT_ROLE = 0, + + // Splits a run of codepoints before the current codepoint. + SPLIT_BEFORE = 1, + + // Splits a run of codepoints after the current codepoint. + SPLIT_AFTER = 2, + + // Each codepoint will be a separate token. Good e.g. for Chinese + // characters. + TOKEN_SEPARATOR = 3, + + // Discards the codepoint. + DISCARD_CODEPOINT = 4, + + // Common values: + // Splits on the characters and discards them. Good e.g. for the space + // character. 
+ WHITESPACE_SEPARATOR = 7, +} + +// Represents a codepoint range [start, end) with its role for tokenization. +namespace libtextclassifier2; +table TokenizationCodepointRange { + start:int; + end:int; + role:libtextclassifier2.TokenizationCodepointRange_.Role; + + // Integer identifier of the script this range denotes. Negative values are + // reserved for Tokenizer's internal use. + script_id:int; +} + +// Method for selecting the center token. +namespace libtextclassifier2.FeatureProcessorOptions_; +enum CenterTokenSelectionMethod : int { + DEFAULT_CENTER_TOKEN_METHOD = 0, + + // Use click indices to determine the center token. + CENTER_TOKEN_FROM_CLICK = 1, + + // Use selection indices to get a token range, and select the middle of it + // as the center token. + CENTER_TOKEN_MIDDLE_OF_SELECTION = 2, +} + +// Controls the type of tokenization the model will use for the input text. +namespace libtextclassifier2.FeatureProcessorOptions_; +enum TokenizationType : int { + INVALID_TOKENIZATION_TYPE = 0, + + // Use the internal tokenizer for tokenization. + INTERNAL_TOKENIZER = 1, + + // Use ICU for tokenization. + ICU = 2, + + // First apply ICU tokenization. Then identify stretches of tokens + // consisting only of codepoints in internal_tokenizer_codepoint_ranges + // and re-tokenize them using the internal tokenizer. + MIXED = 3, +} + +// Range of codepoints start - end, where end is exclusive. +namespace libtextclassifier2.FeatureProcessorOptions_; +table CodepointRange { + start:int; + end:int; +} + +// Bounds-sensitive feature extraction configuration. +namespace libtextclassifier2.FeatureProcessorOptions_; +table BoundsSensitiveFeatures { + // Enables the extraction of bounds-sensitive features, instead of the click + // context features. + enabled:bool; + + // The numbers of tokens to extract in specific locations relative to the + // bounds. + // Immediately before the span. + num_tokens_before:int; + + // Inside the span, aligned with the beginning. 
+ num_tokens_inside_left:int; + + // Inside the span, aligned with the end. + num_tokens_inside_right:int; + + // Immediately after the span. + num_tokens_after:int; + + // If true, also extracts the tokens of the entire span and adds up their + // features forming one "token" to include in the extracted features. + include_inside_bag:bool; + + // If true, includes the selection length (in the number of tokens) as a + // feature. + include_inside_length:bool; + + // If true, for selection, single token spans are not run through the model + // and their score is assumed to be zero. + score_single_token_spans_as_zero:bool; +} + +namespace libtextclassifier2.FeatureProcessorOptions_; +table AlternativeCollectionMapEntry { + key:string; + value:string; +} + +namespace libtextclassifier2; +table FeatureProcessorOptions { + // Number of buckets used for hashing charactergrams. + num_buckets:int = -1; + + // Size of the embedding. + embedding_size:int = -1; + + // Number of bits for quantization for embeddings. + embedding_quantization_bits:int = 8; + + // Context size defines the number of words to the left and to the right of + // the selected word to be used as context. For example, if context size is + // N, then we take N words to the left and N words to the right of the + // selected word as its context. + context_size:int = -1; + + // Maximum number of words of the context to select in total. + max_selection_span:int = -1; + + // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3 + // character trigrams etc. + chargram_orders:[int]; + + // Maximum length of a word, in codepoints. + max_word_length:int = 20; + + // If true, will use the unicode-aware functionality for extracting features. + unicode_aware_features:bool = 0; + + // Whether to extract the token case feature. + extract_case_feature:bool = 0; + + // Whether to extract the selection mask feature. + extract_selection_mask_feature:bool = 0; + + // List of regexps to run over each token. 
For each regexp, if there is a + // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used. + regexp_feature:[string]; + + // Whether to remap all digits to a single number. + remap_digits:bool = 0; + + // Whether to lower-case each token before generating hashgrams. + lowercase_tokens:bool; + + // If true, the selection classifier output will contain only the selections + // that are feasible (e.g., those that are shorter than max_selection_span), + // if false, the output will be a complete cross-product of possible + // selections to the left and possible selections to the right, including the + // infeasible ones. + // NOTE: Exists mainly for compatibility with older models that were trained + // with the non-reduced output space. + selection_reduced_output_space:bool = 1; + + // Collection names. + collections:[string]; + + // An index of collection in collections to be used if a collection name can't + // be mapped to an id. + default_collection:int = -1; + + // If true, will split the input by lines, and only use the line that contains + // the clicked token. + only_use_line_with_click:bool = 0; + + // If true, will split tokens that contain the selection boundary, at the + // position of the boundary. + // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" + split_tokens_on_selection_boundaries:bool = 0; + + // Codepoint ranges that determine how different codepoints are tokenized. + // The ranges must not overlap. + tokenization_codepoint_config:[libtextclassifier2.TokenizationCodepointRange]; + + center_token_selection_method:libtextclassifier2.FeatureProcessorOptions_.CenterTokenSelectionMethod; + + // If true, span boundaries will be snapped to containing tokens and not + // required to exactly match token boundaries. + snap_label_span_boundaries_to_containing_tokens:bool; + + // A set of codepoint ranges supported by the model. 
+ supported_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange]; + + // A set of codepoint ranges to use in the mixed tokenization mode to identify + // stretches of tokens to re-tokenize using the internal tokenizer. + internal_tokenizer_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange]; + + // Minimum ratio of supported codepoints in the input context. If the ratio + // is lower than this, the feature computation will fail. + min_supported_codepoint_ratio:float = 0; + + // Used for versioning the format of features the model expects. + // - feature_version == 0: + // For each token the features consist of: + // - chargram embeddings + // - dense features + // Chargram embeddings for tokens are concatenated first together, + // and at the end, the dense features for the tokens are concatenated + // to it. So the resulting feature vector has two regions. + feature_version:int = 0; + + tokenization_type:libtextclassifier2.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER; + icu_preserve_whitespace_tokens:bool = 0; + + // List of codepoints that will be stripped from beginning and end of + // predicted spans. + ignored_span_boundary_codepoints:[int]; + + bounds_sensitive_features:libtextclassifier2.FeatureProcessorOptions_.BoundsSensitiveFeatures; + + // List of allowed charactergrams. The extracted charactergrams are filtered + // using this list, and charactergrams that are not present are interpreted as + // out-of-vocabulary. + // If no allowed_chargrams are specified, all charactergrams are allowed. + // The field is typed as bytes type to allow non-UTF8 chargrams. + allowed_chargrams:[string]; + + // If true, tokens will be also split when the codepoint's script_id changes + // as defined in TokenizationCodepointRange. 
+ tokenize_on_script_change:bool = 0; +} + +root_type libtextclassifier2.Model; diff --git a/model_generated.h b/model_generated.h new file mode 100755 index 0000000..6ef75f6 --- /dev/null +++ b/model_generated.h @@ -0,0 +1,3718 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_ +#define FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace libtextclassifier2 { + +struct CompressedBuffer; +struct CompressedBufferT; + +struct SelectionModelOptions; +struct SelectionModelOptionsT; + +struct ClassificationModelOptions; +struct ClassificationModelOptionsT; + +namespace RegexModel_ { + +struct Pattern; +struct PatternT; + +} // namespace RegexModel_ + +struct RegexModel; +struct RegexModelT; + +namespace DatetimeModelPattern_ { + +struct Regex; +struct RegexT; + +} // namespace DatetimeModelPattern_ + +struct DatetimeModelPattern; +struct DatetimeModelPatternT; + +struct DatetimeModelExtractor; +struct DatetimeModelExtractorT; + +struct DatetimeModel; +struct DatetimeModelT; + +namespace DatetimeModelLibrary_ { + +struct Item; +struct ItemT; + +} // namespace DatetimeModelLibrary_ + +struct DatetimeModelLibrary; +struct DatetimeModelLibraryT; + +struct ModelTriggeringOptions; +struct ModelTriggeringOptionsT; + +struct 
OutputOptions; +struct OutputOptionsT; + +struct Model; +struct ModelT; + +struct TokenizationCodepointRange; +struct TokenizationCodepointRangeT; + +namespace FeatureProcessorOptions_ { + +struct CodepointRange; +struct CodepointRangeT; + +struct BoundsSensitiveFeatures; +struct BoundsSensitiveFeaturesT; + +struct AlternativeCollectionMapEntry; +struct AlternativeCollectionMapEntryT; + +} // namespace FeatureProcessorOptions_ + +struct FeatureProcessorOptions; +struct FeatureProcessorOptionsT; + +enum ModeFlag { + ModeFlag_NONE = 0, + ModeFlag_ANNOTATION = 1, + ModeFlag_CLASSIFICATION = 2, + ModeFlag_ANNOTATION_AND_CLASSIFICATION = 3, + ModeFlag_SELECTION = 4, + ModeFlag_ANNOTATION_AND_SELECTION = 5, + ModeFlag_CLASSIFICATION_AND_SELECTION = 6, + ModeFlag_ALL = 7, + ModeFlag_MIN = ModeFlag_NONE, + ModeFlag_MAX = ModeFlag_ALL +}; + +inline ModeFlag (&EnumValuesModeFlag())[8] { + static ModeFlag values[] = { + ModeFlag_NONE, + ModeFlag_ANNOTATION, + ModeFlag_CLASSIFICATION, + ModeFlag_ANNOTATION_AND_CLASSIFICATION, + ModeFlag_SELECTION, + ModeFlag_ANNOTATION_AND_SELECTION, + ModeFlag_CLASSIFICATION_AND_SELECTION, + ModeFlag_ALL + }; + return values; +} + +inline const char **EnumNamesModeFlag() { + static const char *names[] = { + "NONE", + "ANNOTATION", + "CLASSIFICATION", + "ANNOTATION_AND_CLASSIFICATION", + "SELECTION", + "ANNOTATION_AND_SELECTION", + "CLASSIFICATION_AND_SELECTION", + "ALL", + nullptr + }; + return names; +} + +inline const char *EnumNameModeFlag(ModeFlag e) { + const size_t index = static_cast<int>(e); + return EnumNamesModeFlag()[index]; +} + +enum DatetimeExtractorType { + DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0, + DatetimeExtractorType_AM = 1, + DatetimeExtractorType_PM = 2, + DatetimeExtractorType_JANUARY = 3, + DatetimeExtractorType_FEBRUARY = 4, + DatetimeExtractorType_MARCH = 5, + DatetimeExtractorType_APRIL = 6, + DatetimeExtractorType_MAY = 7, + DatetimeExtractorType_JUNE = 8, + DatetimeExtractorType_JULY = 9, + 
DatetimeExtractorType_AUGUST = 10, + DatetimeExtractorType_SEPTEMBER = 11, + DatetimeExtractorType_OCTOBER = 12, + DatetimeExtractorType_NOVEMBER = 13, + DatetimeExtractorType_DECEMBER = 14, + DatetimeExtractorType_NEXT = 15, + DatetimeExtractorType_NEXT_OR_SAME = 16, + DatetimeExtractorType_LAST = 17, + DatetimeExtractorType_NOW = 18, + DatetimeExtractorType_TOMORROW = 19, + DatetimeExtractorType_YESTERDAY = 20, + DatetimeExtractorType_PAST = 21, + DatetimeExtractorType_FUTURE = 22, + DatetimeExtractorType_DAY = 23, + DatetimeExtractorType_WEEK = 24, + DatetimeExtractorType_MONTH = 25, + DatetimeExtractorType_YEAR = 26, + DatetimeExtractorType_MONDAY = 27, + DatetimeExtractorType_TUESDAY = 28, + DatetimeExtractorType_WEDNESDAY = 29, + DatetimeExtractorType_THURSDAY = 30, + DatetimeExtractorType_FRIDAY = 31, + DatetimeExtractorType_SATURDAY = 32, + DatetimeExtractorType_SUNDAY = 33, + DatetimeExtractorType_DAYS = 34, + DatetimeExtractorType_WEEKS = 35, + DatetimeExtractorType_MONTHS = 36, + DatetimeExtractorType_HOURS = 37, + DatetimeExtractorType_MINUTES = 38, + DatetimeExtractorType_SECONDS = 39, + DatetimeExtractorType_YEARS = 40, + DatetimeExtractorType_DIGITS = 41, + DatetimeExtractorType_SIGNEDDIGITS = 42, + DatetimeExtractorType_ZERO = 43, + DatetimeExtractorType_ONE = 44, + DatetimeExtractorType_TWO = 45, + DatetimeExtractorType_THREE = 46, + DatetimeExtractorType_FOUR = 47, + DatetimeExtractorType_FIVE = 48, + DatetimeExtractorType_SIX = 49, + DatetimeExtractorType_SEVEN = 50, + DatetimeExtractorType_EIGHT = 51, + DatetimeExtractorType_NINE = 52, + DatetimeExtractorType_TEN = 53, + DatetimeExtractorType_ELEVEN = 54, + DatetimeExtractorType_TWELVE = 55, + DatetimeExtractorType_THIRTEEN = 56, + DatetimeExtractorType_FOURTEEN = 57, + DatetimeExtractorType_FIFTEEN = 58, + DatetimeExtractorType_SIXTEEN = 59, + DatetimeExtractorType_SEVENTEEN = 60, + DatetimeExtractorType_EIGHTEEN = 61, + DatetimeExtractorType_NINETEEN = 62, + DatetimeExtractorType_TWENTY = 63, 
+ DatetimeExtractorType_THIRTY = 64, + DatetimeExtractorType_FORTY = 65, + DatetimeExtractorType_FIFTY = 66, + DatetimeExtractorType_SIXTY = 67, + DatetimeExtractorType_SEVENTY = 68, + DatetimeExtractorType_EIGHTY = 69, + DatetimeExtractorType_NINETY = 70, + DatetimeExtractorType_HUNDRED = 71, + DatetimeExtractorType_THOUSAND = 72, + DatetimeExtractorType_MIN = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, + DatetimeExtractorType_MAX = DatetimeExtractorType_THOUSAND +}; + +inline DatetimeExtractorType (&EnumValuesDatetimeExtractorType())[73] { + static DatetimeExtractorType values[] = { + DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, + DatetimeExtractorType_AM, + DatetimeExtractorType_PM, + DatetimeExtractorType_JANUARY, + DatetimeExtractorType_FEBRUARY, + DatetimeExtractorType_MARCH, + DatetimeExtractorType_APRIL, + DatetimeExtractorType_MAY, + DatetimeExtractorType_JUNE, + DatetimeExtractorType_JULY, + DatetimeExtractorType_AUGUST, + DatetimeExtractorType_SEPTEMBER, + DatetimeExtractorType_OCTOBER, + DatetimeExtractorType_NOVEMBER, + DatetimeExtractorType_DECEMBER, + DatetimeExtractorType_NEXT, + DatetimeExtractorType_NEXT_OR_SAME, + DatetimeExtractorType_LAST, + DatetimeExtractorType_NOW, + DatetimeExtractorType_TOMORROW, + DatetimeExtractorType_YESTERDAY, + DatetimeExtractorType_PAST, + DatetimeExtractorType_FUTURE, + DatetimeExtractorType_DAY, + DatetimeExtractorType_WEEK, + DatetimeExtractorType_MONTH, + DatetimeExtractorType_YEAR, + DatetimeExtractorType_MONDAY, + DatetimeExtractorType_TUESDAY, + DatetimeExtractorType_WEDNESDAY, + DatetimeExtractorType_THURSDAY, + DatetimeExtractorType_FRIDAY, + DatetimeExtractorType_SATURDAY, + DatetimeExtractorType_SUNDAY, + DatetimeExtractorType_DAYS, + DatetimeExtractorType_WEEKS, + DatetimeExtractorType_MONTHS, + DatetimeExtractorType_HOURS, + DatetimeExtractorType_MINUTES, + DatetimeExtractorType_SECONDS, + DatetimeExtractorType_YEARS, + DatetimeExtractorType_DIGITS, + 
DatetimeExtractorType_SIGNEDDIGITS, + DatetimeExtractorType_ZERO, + DatetimeExtractorType_ONE, + DatetimeExtractorType_TWO, + DatetimeExtractorType_THREE, + DatetimeExtractorType_FOUR, + DatetimeExtractorType_FIVE, + DatetimeExtractorType_SIX, + DatetimeExtractorType_SEVEN, + DatetimeExtractorType_EIGHT, + DatetimeExtractorType_NINE, + DatetimeExtractorType_TEN, + DatetimeExtractorType_ELEVEN, + DatetimeExtractorType_TWELVE, + DatetimeExtractorType_THIRTEEN, + DatetimeExtractorType_FOURTEEN, + DatetimeExtractorType_FIFTEEN, + DatetimeExtractorType_SIXTEEN, + DatetimeExtractorType_SEVENTEEN, + DatetimeExtractorType_EIGHTEEN, + DatetimeExtractorType_NINETEEN, + DatetimeExtractorType_TWENTY, + DatetimeExtractorType_THIRTY, + DatetimeExtractorType_FORTY, + DatetimeExtractorType_FIFTY, + DatetimeExtractorType_SIXTY, + DatetimeExtractorType_SEVENTY, + DatetimeExtractorType_EIGHTY, + DatetimeExtractorType_NINETY, + DatetimeExtractorType_HUNDRED, + DatetimeExtractorType_THOUSAND + }; + return values; +} + +inline const char **EnumNamesDatetimeExtractorType() { + static const char *names[] = { + "UNKNOWN_DATETIME_EXTRACTOR_TYPE", + "AM", + "PM", + "JANUARY", + "FEBRUARY", + "MARCH", + "APRIL", + "MAY", + "JUNE", + "JULY", + "AUGUST", + "SEPTEMBER", + "OCTOBER", + "NOVEMBER", + "DECEMBER", + "NEXT", + "NEXT_OR_SAME", + "LAST", + "NOW", + "TOMORROW", + "YESTERDAY", + "PAST", + "FUTURE", + "DAY", + "WEEK", + "MONTH", + "YEAR", + "MONDAY", + "TUESDAY", + "WEDNESDAY", + "THURSDAY", + "FRIDAY", + "SATURDAY", + "SUNDAY", + "DAYS", + "WEEKS", + "MONTHS", + "HOURS", + "MINUTES", + "SECONDS", + "YEARS", + "DIGITS", + "SIGNEDDIGITS", + "ZERO", + "ONE", + "TWO", + "THREE", + "FOUR", + "FIVE", + "SIX", + "SEVEN", + "EIGHT", + "NINE", + "TEN", + "ELEVEN", + "TWELVE", + "THIRTEEN", + "FOURTEEN", + "FIFTEEN", + "SIXTEEN", + "SEVENTEEN", + "EIGHTEEN", + "NINETEEN", + "TWENTY", + "THIRTY", + "FORTY", + "FIFTY", + "SIXTY", + "SEVENTY", + "EIGHTY", + "NINETY", + "HUNDRED", + "THOUSAND", + 
nullptr + }; + return names; +} + +inline const char *EnumNameDatetimeExtractorType(DatetimeExtractorType e) { + const size_t index = static_cast<int>(e); + return EnumNamesDatetimeExtractorType()[index]; +} + +enum DatetimeGroupType { + DatetimeGroupType_GROUP_UNKNOWN = 0, + DatetimeGroupType_GROUP_UNUSED = 1, + DatetimeGroupType_GROUP_YEAR = 2, + DatetimeGroupType_GROUP_MONTH = 3, + DatetimeGroupType_GROUP_DAY = 4, + DatetimeGroupType_GROUP_HOUR = 5, + DatetimeGroupType_GROUP_MINUTE = 6, + DatetimeGroupType_GROUP_SECOND = 7, + DatetimeGroupType_GROUP_AMPM = 8, + DatetimeGroupType_GROUP_RELATIONDISTANCE = 9, + DatetimeGroupType_GROUP_RELATION = 10, + DatetimeGroupType_GROUP_RELATIONTYPE = 11, + DatetimeGroupType_GROUP_DUMMY1 = 12, + DatetimeGroupType_GROUP_DUMMY2 = 13, + DatetimeGroupType_MIN = DatetimeGroupType_GROUP_UNKNOWN, + DatetimeGroupType_MAX = DatetimeGroupType_GROUP_DUMMY2 +}; + +inline DatetimeGroupType (&EnumValuesDatetimeGroupType())[14] { + static DatetimeGroupType values[] = { + DatetimeGroupType_GROUP_UNKNOWN, + DatetimeGroupType_GROUP_UNUSED, + DatetimeGroupType_GROUP_YEAR, + DatetimeGroupType_GROUP_MONTH, + DatetimeGroupType_GROUP_DAY, + DatetimeGroupType_GROUP_HOUR, + DatetimeGroupType_GROUP_MINUTE, + DatetimeGroupType_GROUP_SECOND, + DatetimeGroupType_GROUP_AMPM, + DatetimeGroupType_GROUP_RELATIONDISTANCE, + DatetimeGroupType_GROUP_RELATION, + DatetimeGroupType_GROUP_RELATIONTYPE, + DatetimeGroupType_GROUP_DUMMY1, + DatetimeGroupType_GROUP_DUMMY2 + }; + return values; +} + +inline const char **EnumNamesDatetimeGroupType() { + static const char *names[] = { + "GROUP_UNKNOWN", + "GROUP_UNUSED", + "GROUP_YEAR", + "GROUP_MONTH", + "GROUP_DAY", + "GROUP_HOUR", + "GROUP_MINUTE", + "GROUP_SECOND", + "GROUP_AMPM", + "GROUP_RELATIONDISTANCE", + "GROUP_RELATION", + "GROUP_RELATIONTYPE", + "GROUP_DUMMY1", + "GROUP_DUMMY2", + nullptr + }; + return names; +} + +inline const char *EnumNameDatetimeGroupType(DatetimeGroupType e) { + const size_t index = 
static_cast<int>(e); + return EnumNamesDatetimeGroupType()[index]; +} + +namespace TokenizationCodepointRange_ { + +enum Role { + Role_DEFAULT_ROLE = 0, + Role_SPLIT_BEFORE = 1, + Role_SPLIT_AFTER = 2, + Role_TOKEN_SEPARATOR = 3, + Role_DISCARD_CODEPOINT = 4, + Role_WHITESPACE_SEPARATOR = 7, + Role_MIN = Role_DEFAULT_ROLE, + Role_MAX = Role_WHITESPACE_SEPARATOR +}; + +inline Role (&EnumValuesRole())[6] { + static Role values[] = { + Role_DEFAULT_ROLE, + Role_SPLIT_BEFORE, + Role_SPLIT_AFTER, + Role_TOKEN_SEPARATOR, + Role_DISCARD_CODEPOINT, + Role_WHITESPACE_SEPARATOR + }; + return values; +} + +inline const char **EnumNamesRole() { + static const char *names[] = { + "DEFAULT_ROLE", + "SPLIT_BEFORE", + "SPLIT_AFTER", + "TOKEN_SEPARATOR", + "DISCARD_CODEPOINT", + "", + "", + "WHITESPACE_SEPARATOR", + nullptr + }; + return names; +} + +inline const char *EnumNameRole(Role e) { + const size_t index = static_cast<int>(e); + return EnumNamesRole()[index]; +} + +} // namespace TokenizationCodepointRange_ + +namespace FeatureProcessorOptions_ { + +enum CenterTokenSelectionMethod { + CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD = 0, + CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK = 1, + CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION = 2, + CenterTokenSelectionMethod_MIN = CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD, + CenterTokenSelectionMethod_MAX = CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION +}; + +inline CenterTokenSelectionMethod (&EnumValuesCenterTokenSelectionMethod())[3] { + static CenterTokenSelectionMethod values[] = { + CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD, + CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK, + CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION + }; + return values; +} + +inline const char **EnumNamesCenterTokenSelectionMethod() { + static const char *names[] = { + "DEFAULT_CENTER_TOKEN_METHOD", + "CENTER_TOKEN_FROM_CLICK", + "CENTER_TOKEN_MIDDLE_OF_SELECTION", + 
nullptr + }; + return names; +} + +inline const char *EnumNameCenterTokenSelectionMethod(CenterTokenSelectionMethod e) { + const size_t index = static_cast<int>(e); + return EnumNamesCenterTokenSelectionMethod()[index]; +} + +enum TokenizationType { + TokenizationType_INVALID_TOKENIZATION_TYPE = 0, + TokenizationType_INTERNAL_TOKENIZER = 1, + TokenizationType_ICU = 2, + TokenizationType_MIXED = 3, + TokenizationType_MIN = TokenizationType_INVALID_TOKENIZATION_TYPE, + TokenizationType_MAX = TokenizationType_MIXED +}; + +inline TokenizationType (&EnumValuesTokenizationType())[4] { + static TokenizationType values[] = { + TokenizationType_INVALID_TOKENIZATION_TYPE, + TokenizationType_INTERNAL_TOKENIZER, + TokenizationType_ICU, + TokenizationType_MIXED + }; + return values; +} + +inline const char **EnumNamesTokenizationType() { + static const char *names[] = { + "INVALID_TOKENIZATION_TYPE", + "INTERNAL_TOKENIZER", + "ICU", + "MIXED", + nullptr + }; + return names; +} + +inline const char *EnumNameTokenizationType(TokenizationType e) { + const size_t index = static_cast<int>(e); + return EnumNamesTokenizationType()[index]; +} + +} // namespace FeatureProcessorOptions_ + +struct CompressedBufferT : public flatbuffers::NativeTable { + typedef CompressedBuffer TableType; + std::vector<uint8_t> buffer; + int32_t uncompressed_size; + CompressedBufferT() + : uncompressed_size(0) { + } +}; + +struct CompressedBuffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CompressedBufferT NativeTableType; + enum { + VT_BUFFER = 4, + VT_UNCOMPRESSED_SIZE = 6 + }; + const flatbuffers::Vector<uint8_t> *buffer() const { + return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_BUFFER); + } + int32_t uncompressed_size() const { + return GetField<int32_t>(VT_UNCOMPRESSED_SIZE, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BUFFER) && + verifier.Verify(buffer()) && + 
VerifyField<int32_t>(verifier, VT_UNCOMPRESSED_SIZE) && + verifier.EndTable(); + } + CompressedBufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CompressedBufferT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<CompressedBuffer> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CompressedBufferBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_buffer(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer) { + fbb_.AddOffset(CompressedBuffer::VT_BUFFER, buffer); + } + void add_uncompressed_size(int32_t uncompressed_size) { + fbb_.AddElement<int32_t>(CompressedBuffer::VT_UNCOMPRESSED_SIZE, uncompressed_size, 0); + } + explicit CompressedBufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CompressedBufferBuilder &operator=(const CompressedBufferBuilder &); + flatbuffers::Offset<CompressedBuffer> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<CompressedBuffer>(end); + return o; + } +}; + +inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer = 0, + int32_t uncompressed_size = 0) { + CompressedBufferBuilder builder_(_fbb); + builder_.add_uncompressed_size(uncompressed_size); + builder_.add_buffer(buffer); + return builder_.Finish(); +} + +inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBufferDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<uint8_t> *buffer = nullptr, + int32_t uncompressed_size = 0) { + return libtextclassifier2::CreateCompressedBuffer( + _fbb, + buffer ? 
_fbb.CreateVector<uint8_t>(*buffer) : 0, + uncompressed_size); +} + +flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SelectionModelOptionsT : public flatbuffers::NativeTable { + typedef SelectionModelOptions TableType; + bool strip_unpaired_brackets; + int32_t symmetry_context_size; + int32_t batch_size; + bool always_classify_suggested_selection; + SelectionModelOptionsT() + : strip_unpaired_brackets(true), + symmetry_context_size(0), + batch_size(1024), + always_classify_suggested_selection(false) { + } +}; + +struct SelectionModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SelectionModelOptionsT NativeTableType; + enum { + VT_STRIP_UNPAIRED_BRACKETS = 4, + VT_SYMMETRY_CONTEXT_SIZE = 6, + VT_BATCH_SIZE = 8, + VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION = 10 + }; + bool strip_unpaired_brackets() const { + return GetField<uint8_t>(VT_STRIP_UNPAIRED_BRACKETS, 1) != 0; + } + int32_t symmetry_context_size() const { + return GetField<int32_t>(VT_SYMMETRY_CONTEXT_SIZE, 0); + } + int32_t batch_size() const { + return GetField<int32_t>(VT_BATCH_SIZE, 1024); + } + bool always_classify_suggested_selection() const { + return GetField<uint8_t>(VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<uint8_t>(verifier, VT_STRIP_UNPAIRED_BRACKETS) && + VerifyField<int32_t>(verifier, VT_SYMMETRY_CONTEXT_SIZE) && + VerifyField<int32_t>(verifier, VT_BATCH_SIZE) && + VerifyField<uint8_t>(verifier, VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION) && + verifier.EndTable(); + } + SelectionModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SelectionModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static 
flatbuffers::Offset<SelectionModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SelectionModelOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_strip_unpaired_brackets(bool strip_unpaired_brackets) { + fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_STRIP_UNPAIRED_BRACKETS, static_cast<uint8_t>(strip_unpaired_brackets), 1); + } + void add_symmetry_context_size(int32_t symmetry_context_size) { + fbb_.AddElement<int32_t>(SelectionModelOptions::VT_SYMMETRY_CONTEXT_SIZE, symmetry_context_size, 0); + } + void add_batch_size(int32_t batch_size) { + fbb_.AddElement<int32_t>(SelectionModelOptions::VT_BATCH_SIZE, batch_size, 1024); + } + void add_always_classify_suggested_selection(bool always_classify_suggested_selection) { + fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION, static_cast<uint8_t>(always_classify_suggested_selection), 0); + } + explicit SelectionModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SelectionModelOptionsBuilder &operator=(const SelectionModelOptionsBuilder &); + flatbuffers::Offset<SelectionModelOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<SelectionModelOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool strip_unpaired_brackets = true, + int32_t symmetry_context_size = 0, + int32_t batch_size = 1024, + bool always_classify_suggested_selection = false) { + SelectionModelOptionsBuilder builder_(_fbb); + builder_.add_batch_size(batch_size); + builder_.add_symmetry_context_size(symmetry_context_size); + builder_.add_always_classify_suggested_selection(always_classify_suggested_selection); + 
builder_.add_strip_unpaired_brackets(strip_unpaired_brackets); + return builder_.Finish(); +} + +flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ClassificationModelOptionsT : public flatbuffers::NativeTable { + typedef ClassificationModelOptions TableType; + int32_t phone_min_num_digits; + int32_t phone_max_num_digits; + int32_t address_min_num_tokens; + int32_t max_num_tokens; + ClassificationModelOptionsT() + : phone_min_num_digits(7), + phone_max_num_digits(15), + address_min_num_tokens(0), + max_num_tokens(-1) { + } +}; + +struct ClassificationModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ClassificationModelOptionsT NativeTableType; + enum { + VT_PHONE_MIN_NUM_DIGITS = 4, + VT_PHONE_MAX_NUM_DIGITS = 6, + VT_ADDRESS_MIN_NUM_TOKENS = 8, + VT_MAX_NUM_TOKENS = 10 + }; + int32_t phone_min_num_digits() const { + return GetField<int32_t>(VT_PHONE_MIN_NUM_DIGITS, 7); + } + int32_t phone_max_num_digits() const { + return GetField<int32_t>(VT_PHONE_MAX_NUM_DIGITS, 15); + } + int32_t address_min_num_tokens() const { + return GetField<int32_t>(VT_ADDRESS_MIN_NUM_TOKENS, 0); + } + int32_t max_num_tokens() const { + return GetField<int32_t>(VT_MAX_NUM_TOKENS, -1); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_PHONE_MIN_NUM_DIGITS) && + VerifyField<int32_t>(verifier, VT_PHONE_MAX_NUM_DIGITS) && + VerifyField<int32_t>(verifier, VT_ADDRESS_MIN_NUM_TOKENS) && + VerifyField<int32_t>(verifier, VT_MAX_NUM_TOKENS) && + verifier.EndTable(); + } + ClassificationModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ClassificationModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static 
flatbuffers::Offset<ClassificationModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ClassificationModelOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_phone_min_num_digits(int32_t phone_min_num_digits) { + fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_PHONE_MIN_NUM_DIGITS, phone_min_num_digits, 7); + } + void add_phone_max_num_digits(int32_t phone_max_num_digits) { + fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_PHONE_MAX_NUM_DIGITS, phone_max_num_digits, 15); + } + void add_address_min_num_tokens(int32_t address_min_num_tokens) { + fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_ADDRESS_MIN_NUM_TOKENS, address_min_num_tokens, 0); + } + void add_max_num_tokens(int32_t max_num_tokens) { + fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_MAX_NUM_TOKENS, max_num_tokens, -1); + } + explicit ClassificationModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ClassificationModelOptionsBuilder &operator=(const ClassificationModelOptionsBuilder &); + flatbuffers::Offset<ClassificationModelOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ClassificationModelOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t phone_min_num_digits = 7, + int32_t phone_max_num_digits = 15, + int32_t address_min_num_tokens = 0, + int32_t max_num_tokens = -1) { + ClassificationModelOptionsBuilder builder_(_fbb); + builder_.add_max_num_tokens(max_num_tokens); + builder_.add_address_min_num_tokens(address_min_num_tokens); + builder_.add_phone_max_num_digits(phone_max_num_digits); + builder_.add_phone_min_num_digits(phone_min_num_digits); + return builder_.Finish(); +} + 
+flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +namespace RegexModel_ { + +struct PatternT : public flatbuffers::NativeTable { + typedef Pattern TableType; + std::string collection_name; + std::string pattern; + libtextclassifier2::ModeFlag enabled_modes; + float target_classification_score; + float priority_score; + bool use_approximate_matching; + std::unique_ptr<libtextclassifier2::CompressedBufferT> compressed_pattern; + PatternT() + : enabled_modes(libtextclassifier2::ModeFlag_ALL), + target_classification_score(1.0f), + priority_score(0.0f), + use_approximate_matching(false) { + } +}; + +struct Pattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PatternT NativeTableType; + enum { + VT_COLLECTION_NAME = 4, + VT_PATTERN = 6, + VT_ENABLED_MODES = 8, + VT_TARGET_CLASSIFICATION_SCORE = 10, + VT_PRIORITY_SCORE = 12, + VT_USE_APPROXIMATE_MATCHING = 14, + VT_COMPRESSED_PATTERN = 16 + }; + const flatbuffers::String *collection_name() const { + return GetPointer<const flatbuffers::String *>(VT_COLLECTION_NAME); + } + const flatbuffers::String *pattern() const { + return GetPointer<const flatbuffers::String *>(VT_PATTERN); + } + libtextclassifier2::ModeFlag enabled_modes() const { + return static_cast<libtextclassifier2::ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); + } + float target_classification_score() const { + return GetField<float>(VT_TARGET_CLASSIFICATION_SCORE, 1.0f); + } + float priority_score() const { + return GetField<float>(VT_PRIORITY_SCORE, 0.0f); + } + bool use_approximate_matching() const { + return GetField<uint8_t>(VT_USE_APPROXIMATE_MATCHING, 0) != 0; + } + const libtextclassifier2::CompressedBuffer *compressed_pattern() const { + return GetPointer<const libtextclassifier2::CompressedBuffer *>(VT_COMPRESSED_PATTERN); + } + bool 
Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_COLLECTION_NAME) && + verifier.Verify(collection_name()) && + VerifyOffset(verifier, VT_PATTERN) && + verifier.Verify(pattern()) && + VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && + VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) && + VerifyField<float>(verifier, VT_PRIORITY_SCORE) && + VerifyField<uint8_t>(verifier, VT_USE_APPROXIMATE_MATCHING) && + VerifyOffset(verifier, VT_COMPRESSED_PATTERN) && + verifier.VerifyTable(compressed_pattern()) && + verifier.EndTable(); + } + PatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<Pattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PatternBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_collection_name(flatbuffers::Offset<flatbuffers::String> collection_name) { + fbb_.AddOffset(Pattern::VT_COLLECTION_NAME, collection_name); + } + void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) { + fbb_.AddOffset(Pattern::VT_PATTERN, pattern); + } + void add_enabled_modes(libtextclassifier2::ModeFlag enabled_modes) { + fbb_.AddElement<int32_t>(Pattern::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); + } + void add_target_classification_score(float target_classification_score) { + fbb_.AddElement<float>(Pattern::VT_TARGET_CLASSIFICATION_SCORE, target_classification_score, 1.0f); + } + void add_priority_score(float priority_score) { + fbb_.AddElement<float>(Pattern::VT_PRIORITY_SCORE, priority_score, 0.0f); + } + void add_use_approximate_matching(bool use_approximate_matching) { + fbb_.AddElement<uint8_t>(Pattern::VT_USE_APPROXIMATE_MATCHING, 
static_cast<uint8_t>(use_approximate_matching), 0); + } + void add_compressed_pattern(flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern) { + fbb_.AddOffset(Pattern::VT_COMPRESSED_PATTERN, compressed_pattern); + } + explicit PatternBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PatternBuilder &operator=(const PatternBuilder &); + flatbuffers::Offset<Pattern> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Pattern>(end); + return o; + } +}; + +inline flatbuffers::Offset<Pattern> CreatePattern( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> collection_name = 0, + flatbuffers::Offset<flatbuffers::String> pattern = 0, + libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL, + float target_classification_score = 1.0f, + float priority_score = 0.0f, + bool use_approximate_matching = false, + flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { + PatternBuilder builder_(_fbb); + builder_.add_compressed_pattern(compressed_pattern); + builder_.add_priority_score(priority_score); + builder_.add_target_classification_score(target_classification_score); + builder_.add_enabled_modes(enabled_modes); + builder_.add_pattern(pattern); + builder_.add_collection_name(collection_name); + builder_.add_use_approximate_matching(use_approximate_matching); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Pattern> CreatePatternDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *collection_name = nullptr, + const char *pattern = nullptr, + libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL, + float target_classification_score = 1.0f, + float priority_score = 0.0f, + bool use_approximate_matching = false, + flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { + return 
libtextclassifier2::RegexModel_::CreatePattern( + _fbb, + collection_name ? _fbb.CreateString(collection_name) : 0, + pattern ? _fbb.CreateString(pattern) : 0, + enabled_modes, + target_classification_score, + priority_score, + use_approximate_matching, + compressed_pattern); +} + +flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +} // namespace RegexModel_ + +struct RegexModelT : public flatbuffers::NativeTable { + typedef RegexModel TableType; + std::vector<std::unique_ptr<libtextclassifier2::RegexModel_::PatternT>> patterns; + RegexModelT() { + } +}; + +struct RegexModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef RegexModelT NativeTableType; + enum { + VT_PATTERNS = 4 + }; + const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *patterns() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *>(VT_PATTERNS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PATTERNS) && + verifier.Verify(patterns()) && + verifier.VerifyVectorOfTables(patterns()) && + verifier.EndTable(); + } + RegexModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RegexModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<RegexModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RegexModelBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>> patterns) { + fbb_.AddOffset(RegexModel::VT_PATTERNS, patterns); + } + explicit 
RegexModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + RegexModelBuilder &operator=(const RegexModelBuilder &); + flatbuffers::Offset<RegexModel> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<RegexModel>(end); + return o; + } +}; + +inline flatbuffers::Offset<RegexModel> CreateRegexModel( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>> patterns = 0) { + RegexModelBuilder builder_(_fbb); + builder_.add_patterns(patterns); + return builder_.Finish(); +} + +inline flatbuffers::Offset<RegexModel> CreateRegexModelDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *patterns = nullptr) { + return libtextclassifier2::CreateRegexModel( + _fbb, + patterns ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>(*patterns) : 0); +} + +flatbuffers::Offset<RegexModel> CreateRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +namespace DatetimeModelPattern_ { + +struct RegexT : public flatbuffers::NativeTable { + typedef Regex TableType; + std::string pattern; + std::vector<libtextclassifier2::DatetimeGroupType> groups; + std::unique_ptr<libtextclassifier2::CompressedBufferT> compressed_pattern; + RegexT() { + } +}; + +struct Regex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef RegexT NativeTableType; + enum { + VT_PATTERN = 4, + VT_GROUPS = 6, + VT_COMPRESSED_PATTERN = 8 + }; + const flatbuffers::String *pattern() const { + return GetPointer<const flatbuffers::String *>(VT_PATTERN); + } + const flatbuffers::Vector<int32_t> *groups() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_GROUPS); + } + const libtextclassifier2::CompressedBuffer *compressed_pattern() const { 
+ return GetPointer<const libtextclassifier2::CompressedBuffer *>(VT_COMPRESSED_PATTERN); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PATTERN) && + verifier.Verify(pattern()) && + VerifyOffset(verifier, VT_GROUPS) && + verifier.Verify(groups()) && + VerifyOffset(verifier, VT_COMPRESSED_PATTERN) && + verifier.VerifyTable(compressed_pattern()) && + verifier.EndTable(); + } + RegexT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RegexT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<Regex> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RegexBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) { + fbb_.AddOffset(Regex::VT_PATTERN, pattern); + } + void add_groups(flatbuffers::Offset<flatbuffers::Vector<int32_t>> groups) { + fbb_.AddOffset(Regex::VT_GROUPS, groups); + } + void add_compressed_pattern(flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern) { + fbb_.AddOffset(Regex::VT_COMPRESSED_PATTERN, compressed_pattern); + } + explicit RegexBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + RegexBuilder &operator=(const RegexBuilder &); + flatbuffers::Offset<Regex> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Regex>(end); + return o; + } +}; + +inline flatbuffers::Offset<Regex> CreateRegex( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> pattern = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> groups = 0, + flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { + RegexBuilder builder_(_fbb); + 
builder_.add_compressed_pattern(compressed_pattern); + builder_.add_groups(groups); + builder_.add_pattern(pattern); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Regex> CreateRegexDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *pattern = nullptr, + const std::vector<int32_t> *groups = nullptr, + flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { + return libtextclassifier2::DatetimeModelPattern_::CreateRegex( + _fbb, + pattern ? _fbb.CreateString(pattern) : 0, + groups ? _fbb.CreateVector<int32_t>(*groups) : 0, + compressed_pattern); +} + +flatbuffers::Offset<Regex> CreateRegex(flatbuffers::FlatBufferBuilder &_fbb, const RegexT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +} // namespace DatetimeModelPattern_ + +struct DatetimeModelPatternT : public flatbuffers::NativeTable { + typedef DatetimeModelPattern TableType; + std::vector<std::unique_ptr<libtextclassifier2::DatetimeModelPattern_::RegexT>> regexes; + std::vector<int32_t> locales; + float target_classification_score; + float priority_score; + ModeFlag enabled_modes; + DatetimeModelPatternT() + : target_classification_score(1.0f), + priority_score(0.0f), + enabled_modes(ModeFlag_ALL) { + } +}; + +struct DatetimeModelPattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DatetimeModelPatternT NativeTableType; + enum { + VT_REGEXES = 4, + VT_LOCALES = 6, + VT_TARGET_CLASSIFICATION_SCORE = 8, + VT_PRIORITY_SCORE = 10, + VT_ENABLED_MODES = 12 + }; + const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *regexes() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *>(VT_REGEXES); + } + const flatbuffers::Vector<int32_t> *locales() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES); + } + float target_classification_score() const { + return 
GetField<float>(VT_TARGET_CLASSIFICATION_SCORE, 1.0f); + } + float priority_score() const { + return GetField<float>(VT_PRIORITY_SCORE, 0.0f); + } + ModeFlag enabled_modes() const { + return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_REGEXES) && + verifier.Verify(regexes()) && + verifier.VerifyVectorOfTables(regexes()) && + VerifyOffset(verifier, VT_LOCALES) && + verifier.Verify(locales()) && + VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) && + VerifyField<float>(verifier, VT_PRIORITY_SCORE) && + VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && + verifier.EndTable(); + } + DatetimeModelPatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DatetimeModelPatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<DatetimeModelPattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DatetimeModelPatternBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_regexes(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>> regexes) { + fbb_.AddOffset(DatetimeModelPattern::VT_REGEXES, regexes); + } + void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) { + fbb_.AddOffset(DatetimeModelPattern::VT_LOCALES, locales); + } + void add_target_classification_score(float target_classification_score) { + fbb_.AddElement<float>(DatetimeModelPattern::VT_TARGET_CLASSIFICATION_SCORE, target_classification_score, 1.0f); + } + void add_priority_score(float priority_score) { + fbb_.AddElement<float>(DatetimeModelPattern::VT_PRIORITY_SCORE, priority_score, 0.0f); + } + void add_enabled_modes(ModeFlag enabled_modes) { 
+ fbb_.AddElement<int32_t>(DatetimeModelPattern::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); + } + explicit DatetimeModelPatternBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DatetimeModelPatternBuilder &operator=(const DatetimeModelPatternBuilder &); + flatbuffers::Offset<DatetimeModelPattern> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<DatetimeModelPattern>(end); + return o; + } +}; + +inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>> regexes = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0, + float target_classification_score = 1.0f, + float priority_score = 0.0f, + ModeFlag enabled_modes = ModeFlag_ALL) { + DatetimeModelPatternBuilder builder_(_fbb); + builder_.add_enabled_modes(enabled_modes); + builder_.add_priority_score(priority_score); + builder_.add_target_classification_score(target_classification_score); + builder_.add_locales(locales); + builder_.add_regexes(regexes); + return builder_.Finish(); +} + +inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPatternDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *regexes = nullptr, + const std::vector<int32_t> *locales = nullptr, + float target_classification_score = 1.0f, + float priority_score = 0.0f, + ModeFlag enabled_modes = ModeFlag_ALL) { + return libtextclassifier2::CreateDatetimeModelPattern( + _fbb, + regexes ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>(*regexes) : 0, + locales ? 
_fbb.CreateVector<int32_t>(*locales) : 0, + target_classification_score, + priority_score, + enabled_modes); +} + +flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DatetimeModelExtractorT : public flatbuffers::NativeTable { + typedef DatetimeModelExtractor TableType; + DatetimeExtractorType extractor; + std::string pattern; + std::vector<int32_t> locales; + std::unique_ptr<CompressedBufferT> compressed_pattern; + DatetimeModelExtractorT() + : extractor(DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE) { + } +}; + +struct DatetimeModelExtractor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DatetimeModelExtractorT NativeTableType; + enum { + VT_EXTRACTOR = 4, + VT_PATTERN = 6, + VT_LOCALES = 8, + VT_COMPRESSED_PATTERN = 10 + }; + DatetimeExtractorType extractor() const { + return static_cast<DatetimeExtractorType>(GetField<int32_t>(VT_EXTRACTOR, 0)); + } + const flatbuffers::String *pattern() const { + return GetPointer<const flatbuffers::String *>(VT_PATTERN); + } + const flatbuffers::Vector<int32_t> *locales() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES); + } + const CompressedBuffer *compressed_pattern() const { + return GetPointer<const CompressedBuffer *>(VT_COMPRESSED_PATTERN); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_EXTRACTOR) && + VerifyOffset(verifier, VT_PATTERN) && + verifier.Verify(pattern()) && + VerifyOffset(verifier, VT_LOCALES) && + verifier.Verify(locales()) && + VerifyOffset(verifier, VT_COMPRESSED_PATTERN) && + verifier.VerifyTable(compressed_pattern()) && + verifier.EndTable(); + } + DatetimeModelExtractorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DatetimeModelExtractorT *_o, const 
flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<DatetimeModelExtractor> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DatetimeModelExtractorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_extractor(DatetimeExtractorType extractor) { + fbb_.AddElement<int32_t>(DatetimeModelExtractor::VT_EXTRACTOR, static_cast<int32_t>(extractor), 0); + } + void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) { + fbb_.AddOffset(DatetimeModelExtractor::VT_PATTERN, pattern); + } + void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) { + fbb_.AddOffset(DatetimeModelExtractor::VT_LOCALES, locales); + } + void add_compressed_pattern(flatbuffers::Offset<CompressedBuffer> compressed_pattern) { + fbb_.AddOffset(DatetimeModelExtractor::VT_COMPRESSED_PATTERN, compressed_pattern); + } + explicit DatetimeModelExtractorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DatetimeModelExtractorBuilder &operator=(const DatetimeModelExtractorBuilder &); + flatbuffers::Offset<DatetimeModelExtractor> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<DatetimeModelExtractor>(end); + return o; + } +}; + +inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor( + flatbuffers::FlatBufferBuilder &_fbb, + DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, + flatbuffers::Offset<flatbuffers::String> pattern = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0, + flatbuffers::Offset<CompressedBuffer> compressed_pattern = 0) { + DatetimeModelExtractorBuilder builder_(_fbb); + builder_.add_compressed_pattern(compressed_pattern); + builder_.add_locales(locales); + builder_.add_pattern(pattern); + 
builder_.add_extractor(extractor); + return builder_.Finish(); +} + +inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, + const char *pattern = nullptr, + const std::vector<int32_t> *locales = nullptr, + flatbuffers::Offset<CompressedBuffer> compressed_pattern = 0) { + return libtextclassifier2::CreateDatetimeModelExtractor( + _fbb, + extractor, + pattern ? _fbb.CreateString(pattern) : 0, + locales ? _fbb.CreateVector<int32_t>(*locales) : 0, + compressed_pattern); +} + +flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DatetimeModelT : public flatbuffers::NativeTable { + typedef DatetimeModel TableType; + std::vector<std::string> locales; + std::vector<std::unique_ptr<DatetimeModelPatternT>> patterns; + std::vector<std::unique_ptr<DatetimeModelExtractorT>> extractors; + bool use_extractors_for_locating; + std::vector<int32_t> default_locales; + DatetimeModelT() + : use_extractors_for_locating(true) { + } +}; + +struct DatetimeModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DatetimeModelT NativeTableType; + enum { + VT_LOCALES = 4, + VT_PATTERNS = 6, + VT_EXTRACTORS = 8, + VT_USE_EXTRACTORS_FOR_LOCATING = 10, + VT_DEFAULT_LOCALES = 12 + }; + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *locales() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_LOCALES); + } + const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>> *>(VT_PATTERNS); + } + const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors() 
const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *>(VT_EXTRACTORS); + } + bool use_extractors_for_locating() const { + return GetField<uint8_t>(VT_USE_EXTRACTORS_FOR_LOCATING, 1) != 0; + } + const flatbuffers::Vector<int32_t> *default_locales() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DEFAULT_LOCALES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_LOCALES) && + verifier.Verify(locales()) && + verifier.VerifyVectorOfStrings(locales()) && + VerifyOffset(verifier, VT_PATTERNS) && + verifier.Verify(patterns()) && + verifier.VerifyVectorOfTables(patterns()) && + VerifyOffset(verifier, VT_EXTRACTORS) && + verifier.Verify(extractors()) && + verifier.VerifyVectorOfTables(extractors()) && + VerifyField<uint8_t>(verifier, VT_USE_EXTRACTORS_FOR_LOCATING) && + VerifyOffset(verifier, VT_DEFAULT_LOCALES) && + verifier.Verify(default_locales()) && + verifier.EndTable(); + } + DatetimeModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DatetimeModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<DatetimeModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DatetimeModelBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_locales(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales) { + fbb_.AddOffset(DatetimeModel::VT_LOCALES, locales); + } + void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns) { + fbb_.AddOffset(DatetimeModel::VT_PATTERNS, patterns); + } + void add_extractors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors) { + 
fbb_.AddOffset(DatetimeModel::VT_EXTRACTORS, extractors); + } + void add_use_extractors_for_locating(bool use_extractors_for_locating) { + fbb_.AddElement<uint8_t>(DatetimeModel::VT_USE_EXTRACTORS_FOR_LOCATING, static_cast<uint8_t>(use_extractors_for_locating), 1); + } + void add_default_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> default_locales) { + fbb_.AddOffset(DatetimeModel::VT_DEFAULT_LOCALES, default_locales); + } + explicit DatetimeModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DatetimeModelBuilder &operator=(const DatetimeModelBuilder &); + flatbuffers::Offset<DatetimeModel> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<DatetimeModel>(end); + return o; + } +}; + +inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModel( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors = 0, + bool use_extractors_for_locating = true, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> default_locales = 0) { + DatetimeModelBuilder builder_(_fbb); + builder_.add_default_locales(default_locales); + builder_.add_extractors(extractors); + builder_.add_patterns(patterns); + builder_.add_locales(locales); + builder_.add_use_extractors_for_locating(use_extractors_for_locating); + return builder_.Finish(); +} + +inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModelDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *locales = nullptr, + const std::vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns = nullptr, + const std::vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors = nullptr, + bool 
use_extractors_for_locating = true, + const std::vector<int32_t> *default_locales = nullptr) { + return libtextclassifier2::CreateDatetimeModel( + _fbb, + locales ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*locales) : 0, + patterns ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>>(*patterns) : 0, + extractors ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>>(*extractors) : 0, + use_extractors_for_locating, + default_locales ? _fbb.CreateVector<int32_t>(*default_locales) : 0); +} + +flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +namespace DatetimeModelLibrary_ { + +struct ItemT : public flatbuffers::NativeTable { + typedef Item TableType; + std::string key; + std::unique_ptr<libtextclassifier2::DatetimeModelT> value; + ItemT() { + } +}; + +struct Item FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ItemT NativeTableType; + enum { + VT_KEY = 4, + VT_VALUE = 6 + }; + const flatbuffers::String *key() const { + return GetPointer<const flatbuffers::String *>(VT_KEY); + } + const libtextclassifier2::DatetimeModel *value() const { + return GetPointer<const libtextclassifier2::DatetimeModel *>(VT_VALUE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_KEY) && + verifier.Verify(key()) && + VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyTable(value()) && + verifier.EndTable(); + } + ItemT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ItemT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<Item> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ItemT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ItemBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + 
flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset<flatbuffers::String> key) { + fbb_.AddOffset(Item::VT_KEY, key); + } + void add_value(flatbuffers::Offset<libtextclassifier2::DatetimeModel> value) { + fbb_.AddOffset(Item::VT_VALUE, value); + } + explicit ItemBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ItemBuilder &operator=(const ItemBuilder &); + flatbuffers::Offset<Item> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Item>(end); + return o; + } +}; + +inline flatbuffers::Offset<Item> CreateItem( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> key = 0, + flatbuffers::Offset<libtextclassifier2::DatetimeModel> value = 0) { + ItemBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_key(key); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Item> CreateItemDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *key = nullptr, + flatbuffers::Offset<libtextclassifier2::DatetimeModel> value = 0) { + return libtextclassifier2::DatetimeModelLibrary_::CreateItem( + _fbb, + key ? 
_fbb.CreateString(key) : 0, + value); +} + +flatbuffers::Offset<Item> CreateItem(flatbuffers::FlatBufferBuilder &_fbb, const ItemT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +} // namespace DatetimeModelLibrary_ + +struct DatetimeModelLibraryT : public flatbuffers::NativeTable { + typedef DatetimeModelLibrary TableType; + std::vector<std::unique_ptr<libtextclassifier2::DatetimeModelLibrary_::ItemT>> models; + DatetimeModelLibraryT() { + } +}; + +struct DatetimeModelLibrary FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DatetimeModelLibraryT NativeTableType; + enum { + VT_MODELS = 4 + }; + const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *models() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *>(VT_MODELS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MODELS) && + verifier.Verify(models()) && + verifier.VerifyVectorOfTables(models()) && + verifier.EndTable(); + } + DatetimeModelLibraryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DatetimeModelLibraryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<DatetimeModelLibrary> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DatetimeModelLibraryBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_models(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>> models) { + fbb_.AddOffset(DatetimeModelLibrary::VT_MODELS, models); + } + explicit DatetimeModelLibraryBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + 
DatetimeModelLibraryBuilder &operator=(const DatetimeModelLibraryBuilder &); + flatbuffers::Offset<DatetimeModelLibrary> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<DatetimeModelLibrary>(end); + return o; + } +}; + +inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>> models = 0) { + DatetimeModelLibraryBuilder builder_(_fbb); + builder_.add_models(models); + return builder_.Finish(); +} + +inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibraryDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *models = nullptr) { + return libtextclassifier2::CreateDatetimeModelLibrary( + _fbb, + models ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>(*models) : 0); +} + +flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ModelTriggeringOptionsT : public flatbuffers::NativeTable { + typedef ModelTriggeringOptions TableType; + float min_annotate_confidence; + ModeFlag enabled_modes; + ModelTriggeringOptionsT() + : min_annotate_confidence(0.0f), + enabled_modes(ModeFlag_ALL) { + } +}; + +struct ModelTriggeringOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ModelTriggeringOptionsT NativeTableType; + enum { + VT_MIN_ANNOTATE_CONFIDENCE = 4, + VT_ENABLED_MODES = 6 + }; + float min_annotate_confidence() const { + return GetField<float>(VT_MIN_ANNOTATE_CONFIDENCE, 0.0f); + } + ModeFlag enabled_modes() const { + return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return 
VerifyTableStart(verifier) && + VerifyField<float>(verifier, VT_MIN_ANNOTATE_CONFIDENCE) && + VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && + verifier.EndTable(); + } + ModelTriggeringOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ModelTriggeringOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<ModelTriggeringOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ModelTriggeringOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_min_annotate_confidence(float min_annotate_confidence) { + fbb_.AddElement<float>(ModelTriggeringOptions::VT_MIN_ANNOTATE_CONFIDENCE, min_annotate_confidence, 0.0f); + } + void add_enabled_modes(ModeFlag enabled_modes) { + fbb_.AddElement<int32_t>(ModelTriggeringOptions::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); + } + explicit ModelTriggeringOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ModelTriggeringOptionsBuilder &operator=(const ModelTriggeringOptionsBuilder &); + flatbuffers::Offset<ModelTriggeringOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ModelTriggeringOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions( + flatbuffers::FlatBufferBuilder &_fbb, + float min_annotate_confidence = 0.0f, + ModeFlag enabled_modes = ModeFlag_ALL) { + ModelTriggeringOptionsBuilder builder_(_fbb); + builder_.add_enabled_modes(enabled_modes); + builder_.add_min_annotate_confidence(min_annotate_confidence); + return builder_.Finish(); +} + +flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT *_o, const 
flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OutputOptionsT : public flatbuffers::NativeTable { + typedef OutputOptions TableType; + std::vector<std::string> filtered_collections_annotation; + std::vector<std::string> filtered_collections_classification; + std::vector<std::string> filtered_collections_selection; + OutputOptionsT() { + } +}; + +struct OutputOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OutputOptionsT NativeTableType; + enum { + VT_FILTERED_COLLECTIONS_ANNOTATION = 4, + VT_FILTERED_COLLECTIONS_CLASSIFICATION = 6, + VT_FILTERED_COLLECTIONS_SELECTION = 8 + }; + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_annotation() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_ANNOTATION); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_classification() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_CLASSIFICATION); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_selection() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_SELECTION); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_ANNOTATION) && + verifier.Verify(filtered_collections_annotation()) && + verifier.VerifyVectorOfStrings(filtered_collections_annotation()) && + VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_CLASSIFICATION) && + verifier.Verify(filtered_collections_classification()) && + verifier.VerifyVectorOfStrings(filtered_collections_classification()) && + VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_SELECTION) && + verifier.Verify(filtered_collections_selection()) && + 
verifier.VerifyVectorOfStrings(filtered_collections_selection()) && + verifier.EndTable(); + } + OutputOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OutputOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<OutputOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OutputOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_filtered_collections_annotation(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_annotation) { + fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_ANNOTATION, filtered_collections_annotation); + } + void add_filtered_collections_classification(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_classification) { + fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_CLASSIFICATION, filtered_collections_classification); + } + void add_filtered_collections_selection(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_selection) { + fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_SELECTION, filtered_collections_selection); + } + explicit OutputOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OutputOptionsBuilder &operator=(const OutputOptionsBuilder &); + flatbuffers::Offset<OutputOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<OutputOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<OutputOptions> CreateOutputOptions( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_annotation = 0, + 
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_classification = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_selection = 0) { + OutputOptionsBuilder builder_(_fbb); + builder_.add_filtered_collections_selection(filtered_collections_selection); + builder_.add_filtered_collections_classification(filtered_collections_classification); + builder_.add_filtered_collections_annotation(filtered_collections_annotation); + return builder_.Finish(); +} + +inline flatbuffers::Offset<OutputOptions> CreateOutputOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_annotation = nullptr, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_classification = nullptr, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_selection = nullptr) { + return libtextclassifier2::CreateOutputOptions( + _fbb, + filtered_collections_annotation ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_annotation) : 0, + filtered_collections_classification ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_classification) : 0, + filtered_collections_selection ? 
_fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_selection) : 0); +} + +flatbuffers::Offset<OutputOptions> CreateOutputOptions(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ModelT : public flatbuffers::NativeTable { + typedef Model TableType; + std::string locales; + int32_t version; + std::string name; + std::unique_ptr<FeatureProcessorOptionsT> selection_feature_options; + std::unique_ptr<FeatureProcessorOptionsT> classification_feature_options; + std::vector<uint8_t> selection_model; + std::vector<uint8_t> classification_model; + std::vector<uint8_t> embedding_model; + std::unique_ptr<SelectionModelOptionsT> selection_options; + std::unique_ptr<ClassificationModelOptionsT> classification_options; + std::unique_ptr<RegexModelT> regex_model; + std::unique_ptr<DatetimeModelT> datetime_model; + std::unique_ptr<ModelTriggeringOptionsT> triggering_options; + ModeFlag enabled_modes; + bool snap_whitespace_selections; + std::unique_ptr<OutputOptionsT> output_options; + ModelT() + : version(0), + enabled_modes(ModeFlag_ALL), + snap_whitespace_selections(true) { + } +}; + +struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ModelT NativeTableType; + enum { + VT_LOCALES = 4, + VT_VERSION = 6, + VT_NAME = 8, + VT_SELECTION_FEATURE_OPTIONS = 10, + VT_CLASSIFICATION_FEATURE_OPTIONS = 12, + VT_SELECTION_MODEL = 14, + VT_CLASSIFICATION_MODEL = 16, + VT_EMBEDDING_MODEL = 18, + VT_SELECTION_OPTIONS = 20, + VT_CLASSIFICATION_OPTIONS = 22, + VT_REGEX_MODEL = 24, + VT_DATETIME_MODEL = 26, + VT_TRIGGERING_OPTIONS = 28, + VT_ENABLED_MODES = 30, + VT_SNAP_WHITESPACE_SELECTIONS = 32, + VT_OUTPUT_OPTIONS = 34 + }; + const flatbuffers::String *locales() const { + return GetPointer<const flatbuffers::String *>(VT_LOCALES); + } + int32_t version() const { + return GetField<int32_t>(VT_VERSION, 0); + } + const flatbuffers::String *name() const { 
+ return GetPointer<const flatbuffers::String *>(VT_NAME); + } + const FeatureProcessorOptions *selection_feature_options() const { + return GetPointer<const FeatureProcessorOptions *>(VT_SELECTION_FEATURE_OPTIONS); + } + const FeatureProcessorOptions *classification_feature_options() const { + return GetPointer<const FeatureProcessorOptions *>(VT_CLASSIFICATION_FEATURE_OPTIONS); + } + const flatbuffers::Vector<uint8_t> *selection_model() const { + return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_SELECTION_MODEL); + } + const flatbuffers::Vector<uint8_t> *classification_model() const { + return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CLASSIFICATION_MODEL); + } + const flatbuffers::Vector<uint8_t> *embedding_model() const { + return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_EMBEDDING_MODEL); + } + const SelectionModelOptions *selection_options() const { + return GetPointer<const SelectionModelOptions *>(VT_SELECTION_OPTIONS); + } + const ClassificationModelOptions *classification_options() const { + return GetPointer<const ClassificationModelOptions *>(VT_CLASSIFICATION_OPTIONS); + } + const RegexModel *regex_model() const { + return GetPointer<const RegexModel *>(VT_REGEX_MODEL); + } + const DatetimeModel *datetime_model() const { + return GetPointer<const DatetimeModel *>(VT_DATETIME_MODEL); + } + const ModelTriggeringOptions *triggering_options() const { + return GetPointer<const ModelTriggeringOptions *>(VT_TRIGGERING_OPTIONS); + } + ModeFlag enabled_modes() const { + return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); + } + bool snap_whitespace_selections() const { + return GetField<uint8_t>(VT_SNAP_WHITESPACE_SELECTIONS, 1) != 0; + } + const OutputOptions *output_options() const { + return GetPointer<const OutputOptions *>(VT_OUTPUT_OPTIONS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_LOCALES) && + verifier.Verify(locales()) && + 
VerifyField<int32_t>(verifier, VT_VERSION) && + VerifyOffset(verifier, VT_NAME) && + verifier.Verify(name()) && + VerifyOffset(verifier, VT_SELECTION_FEATURE_OPTIONS) && + verifier.VerifyTable(selection_feature_options()) && + VerifyOffset(verifier, VT_CLASSIFICATION_FEATURE_OPTIONS) && + verifier.VerifyTable(classification_feature_options()) && + VerifyOffset(verifier, VT_SELECTION_MODEL) && + verifier.Verify(selection_model()) && + VerifyOffset(verifier, VT_CLASSIFICATION_MODEL) && + verifier.Verify(classification_model()) && + VerifyOffset(verifier, VT_EMBEDDING_MODEL) && + verifier.Verify(embedding_model()) && + VerifyOffset(verifier, VT_SELECTION_OPTIONS) && + verifier.VerifyTable(selection_options()) && + VerifyOffset(verifier, VT_CLASSIFICATION_OPTIONS) && + verifier.VerifyTable(classification_options()) && + VerifyOffset(verifier, VT_REGEX_MODEL) && + verifier.VerifyTable(regex_model()) && + VerifyOffset(verifier, VT_DATETIME_MODEL) && + verifier.VerifyTable(datetime_model()) && + VerifyOffset(verifier, VT_TRIGGERING_OPTIONS) && + verifier.VerifyTable(triggering_options()) && + VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && + VerifyField<uint8_t>(verifier, VT_SNAP_WHITESPACE_SELECTIONS) && + VerifyOffset(verifier, VT_OUTPUT_OPTIONS) && + verifier.VerifyTable(output_options()) && + verifier.EndTable(); + } + ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<Model> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ModelBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_locales(flatbuffers::Offset<flatbuffers::String> locales) { + fbb_.AddOffset(Model::VT_LOCALES, locales); + } + void add_version(int32_t version) { + fbb_.AddElement<int32_t>(Model::VT_VERSION, version, 0); 
+ } + void add_name(flatbuffers::Offset<flatbuffers::String> name) { + fbb_.AddOffset(Model::VT_NAME, name); + } + void add_selection_feature_options(flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options) { + fbb_.AddOffset(Model::VT_SELECTION_FEATURE_OPTIONS, selection_feature_options); + } + void add_classification_feature_options(flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options) { + fbb_.AddOffset(Model::VT_CLASSIFICATION_FEATURE_OPTIONS, classification_feature_options); + } + void add_selection_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model) { + fbb_.AddOffset(Model::VT_SELECTION_MODEL, selection_model); + } + void add_classification_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model) { + fbb_.AddOffset(Model::VT_CLASSIFICATION_MODEL, classification_model); + } + void add_embedding_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model) { + fbb_.AddOffset(Model::VT_EMBEDDING_MODEL, embedding_model); + } + void add_selection_options(flatbuffers::Offset<SelectionModelOptions> selection_options) { + fbb_.AddOffset(Model::VT_SELECTION_OPTIONS, selection_options); + } + void add_classification_options(flatbuffers::Offset<ClassificationModelOptions> classification_options) { + fbb_.AddOffset(Model::VT_CLASSIFICATION_OPTIONS, classification_options); + } + void add_regex_model(flatbuffers::Offset<RegexModel> regex_model) { + fbb_.AddOffset(Model::VT_REGEX_MODEL, regex_model); + } + void add_datetime_model(flatbuffers::Offset<DatetimeModel> datetime_model) { + fbb_.AddOffset(Model::VT_DATETIME_MODEL, datetime_model); + } + void add_triggering_options(flatbuffers::Offset<ModelTriggeringOptions> triggering_options) { + fbb_.AddOffset(Model::VT_TRIGGERING_OPTIONS, triggering_options); + } + void add_enabled_modes(ModeFlag enabled_modes) { + fbb_.AddElement<int32_t>(Model::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); + } + void 
add_snap_whitespace_selections(bool snap_whitespace_selections) { + fbb_.AddElement<uint8_t>(Model::VT_SNAP_WHITESPACE_SELECTIONS, static_cast<uint8_t>(snap_whitespace_selections), 1); + } + void add_output_options(flatbuffers::Offset<OutputOptions> output_options) { + fbb_.AddOffset(Model::VT_OUTPUT_OPTIONS, output_options); + } + explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ModelBuilder &operator=(const ModelBuilder &); + flatbuffers::Offset<Model> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Model>(end); + return o; + } +}; + +inline flatbuffers::Offset<Model> CreateModel( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> locales = 0, + int32_t version = 0, + flatbuffers::Offset<flatbuffers::String> name = 0, + flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0, + flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0, + flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model = 0, + flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model = 0, + flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model = 0, + flatbuffers::Offset<SelectionModelOptions> selection_options = 0, + flatbuffers::Offset<ClassificationModelOptions> classification_options = 0, + flatbuffers::Offset<RegexModel> regex_model = 0, + flatbuffers::Offset<DatetimeModel> datetime_model = 0, + flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0, + ModeFlag enabled_modes = ModeFlag_ALL, + bool snap_whitespace_selections = true, + flatbuffers::Offset<OutputOptions> output_options = 0) { + ModelBuilder builder_(_fbb); + builder_.add_output_options(output_options); + builder_.add_enabled_modes(enabled_modes); + builder_.add_triggering_options(triggering_options); + builder_.add_datetime_model(datetime_model); + builder_.add_regex_model(regex_model); + 
builder_.add_classification_options(classification_options); + builder_.add_selection_options(selection_options); + builder_.add_embedding_model(embedding_model); + builder_.add_classification_model(classification_model); + builder_.add_selection_model(selection_model); + builder_.add_classification_feature_options(classification_feature_options); + builder_.add_selection_feature_options(selection_feature_options); + builder_.add_name(name); + builder_.add_version(version); + builder_.add_locales(locales); + builder_.add_snap_whitespace_selections(snap_whitespace_selections); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Model> CreateModelDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *locales = nullptr, + int32_t version = 0, + const char *name = nullptr, + flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0, + flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0, + const std::vector<uint8_t> *selection_model = nullptr, + const std::vector<uint8_t> *classification_model = nullptr, + const std::vector<uint8_t> *embedding_model = nullptr, + flatbuffers::Offset<SelectionModelOptions> selection_options = 0, + flatbuffers::Offset<ClassificationModelOptions> classification_options = 0, + flatbuffers::Offset<RegexModel> regex_model = 0, + flatbuffers::Offset<DatetimeModel> datetime_model = 0, + flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0, + ModeFlag enabled_modes = ModeFlag_ALL, + bool snap_whitespace_selections = true, + flatbuffers::Offset<OutputOptions> output_options = 0) { + return libtextclassifier2::CreateModel( + _fbb, + locales ? _fbb.CreateString(locales) : 0, + version, + name ? _fbb.CreateString(name) : 0, + selection_feature_options, + classification_feature_options, + selection_model ? _fbb.CreateVector<uint8_t>(*selection_model) : 0, + classification_model ? _fbb.CreateVector<uint8_t>(*classification_model) : 0, + embedding_model ? 
_fbb.CreateVector<uint8_t>(*embedding_model) : 0, + selection_options, + classification_options, + regex_model, + datetime_model, + triggering_options, + enabled_modes, + snap_whitespace_selections, + output_options); +} + +flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TokenizationCodepointRangeT : public flatbuffers::NativeTable { + typedef TokenizationCodepointRange TableType; + int32_t start; + int32_t end; + libtextclassifier2::TokenizationCodepointRange_::Role role; + int32_t script_id; + TokenizationCodepointRangeT() + : start(0), + end(0), + role(libtextclassifier2::TokenizationCodepointRange_::Role_DEFAULT_ROLE), + script_id(0) { + } +}; + +struct TokenizationCodepointRange FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TokenizationCodepointRangeT NativeTableType; + enum { + VT_START = 4, + VT_END = 6, + VT_ROLE = 8, + VT_SCRIPT_ID = 10 + }; + int32_t start() const { + return GetField<int32_t>(VT_START, 0); + } + int32_t end() const { + return GetField<int32_t>(VT_END, 0); + } + libtextclassifier2::TokenizationCodepointRange_::Role role() const { + return static_cast<libtextclassifier2::TokenizationCodepointRange_::Role>(GetField<int32_t>(VT_ROLE, 0)); + } + int32_t script_id() const { + return GetField<int32_t>(VT_SCRIPT_ID, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_START) && + VerifyField<int32_t>(verifier, VT_END) && + VerifyField<int32_t>(verifier, VT_ROLE) && + VerifyField<int32_t>(verifier, VT_SCRIPT_ID) && + verifier.EndTable(); + } + TokenizationCodepointRangeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TokenizationCodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<TokenizationCodepointRange> 
Pack(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TokenizationCodepointRangeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_start(int32_t start) { + fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_START, start, 0); + } + void add_end(int32_t end) { + fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_END, end, 0); + } + void add_role(libtextclassifier2::TokenizationCodepointRange_::Role role) { + fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_ROLE, static_cast<int32_t>(role), 0); + } + void add_script_id(int32_t script_id) { + fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_SCRIPT_ID, script_id, 0); + } + explicit TokenizationCodepointRangeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TokenizationCodepointRangeBuilder &operator=(const TokenizationCodepointRangeBuilder &); + flatbuffers::Offset<TokenizationCodepointRange> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<TokenizationCodepointRange>(end); + return o; + } +}; + +inline flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t start = 0, + int32_t end = 0, + libtextclassifier2::TokenizationCodepointRange_::Role role = libtextclassifier2::TokenizationCodepointRange_::Role_DEFAULT_ROLE, + int32_t script_id = 0) { + TokenizationCodepointRangeBuilder builder_(_fbb); + builder_.add_script_id(script_id); + builder_.add_role(role); + builder_.add_end(end); + builder_.add_start(start); + return builder_.Finish(); +} + +flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +namespace FeatureProcessorOptions_ 
{ + +struct CodepointRangeT : public flatbuffers::NativeTable { + typedef CodepointRange TableType; + int32_t start; + int32_t end; + CodepointRangeT() + : start(0), + end(0) { + } +}; + +struct CodepointRange FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CodepointRangeT NativeTableType; + enum { + VT_START = 4, + VT_END = 6 + }; + int32_t start() const { + return GetField<int32_t>(VT_START, 0); + } + int32_t end() const { + return GetField<int32_t>(VT_END, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_START) && + VerifyField<int32_t>(verifier, VT_END) && + verifier.EndTable(); + } + CodepointRangeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<CodepointRange> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CodepointRangeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_start(int32_t start) { + fbb_.AddElement<int32_t>(CodepointRange::VT_START, start, 0); + } + void add_end(int32_t end) { + fbb_.AddElement<int32_t>(CodepointRange::VT_END, end, 0); + } + explicit CodepointRangeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CodepointRangeBuilder &operator=(const CodepointRangeBuilder &); + flatbuffers::Offset<CodepointRange> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<CodepointRange>(end); + return o; + } +}; + +inline flatbuffers::Offset<CodepointRange> CreateCodepointRange( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t start = 0, + int32_t end = 0) { + CodepointRangeBuilder builder_(_fbb); + builder_.add_end(end); + builder_.add_start(start); + return 
builder_.Finish(); +} + +flatbuffers::Offset<CodepointRange> CreateCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BoundsSensitiveFeaturesT : public flatbuffers::NativeTable { + typedef BoundsSensitiveFeatures TableType; + bool enabled; + int32_t num_tokens_before; + int32_t num_tokens_inside_left; + int32_t num_tokens_inside_right; + int32_t num_tokens_after; + bool include_inside_bag; + bool include_inside_length; + bool score_single_token_spans_as_zero; + BoundsSensitiveFeaturesT() + : enabled(false), + num_tokens_before(0), + num_tokens_inside_left(0), + num_tokens_inside_right(0), + num_tokens_after(0), + include_inside_bag(false), + include_inside_length(false), + score_single_token_spans_as_zero(false) { + } +}; + +struct BoundsSensitiveFeatures FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BoundsSensitiveFeaturesT NativeTableType; + enum { + VT_ENABLED = 4, + VT_NUM_TOKENS_BEFORE = 6, + VT_NUM_TOKENS_INSIDE_LEFT = 8, + VT_NUM_TOKENS_INSIDE_RIGHT = 10, + VT_NUM_TOKENS_AFTER = 12, + VT_INCLUDE_INSIDE_BAG = 14, + VT_INCLUDE_INSIDE_LENGTH = 16, + VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO = 18 + }; + bool enabled() const { + return GetField<uint8_t>(VT_ENABLED, 0) != 0; + } + int32_t num_tokens_before() const { + return GetField<int32_t>(VT_NUM_TOKENS_BEFORE, 0); + } + int32_t num_tokens_inside_left() const { + return GetField<int32_t>(VT_NUM_TOKENS_INSIDE_LEFT, 0); + } + int32_t num_tokens_inside_right() const { + return GetField<int32_t>(VT_NUM_TOKENS_INSIDE_RIGHT, 0); + } + int32_t num_tokens_after() const { + return GetField<int32_t>(VT_NUM_TOKENS_AFTER, 0); + } + bool include_inside_bag() const { + return GetField<uint8_t>(VT_INCLUDE_INSIDE_BAG, 0) != 0; + } + bool include_inside_length() const { + return GetField<uint8_t>(VT_INCLUDE_INSIDE_LENGTH, 0) != 0; + } + bool score_single_token_spans_as_zero() const { + return 
GetField<uint8_t>(VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<uint8_t>(verifier, VT_ENABLED) && + VerifyField<int32_t>(verifier, VT_NUM_TOKENS_BEFORE) && + VerifyField<int32_t>(verifier, VT_NUM_TOKENS_INSIDE_LEFT) && + VerifyField<int32_t>(verifier, VT_NUM_TOKENS_INSIDE_RIGHT) && + VerifyField<int32_t>(verifier, VT_NUM_TOKENS_AFTER) && + VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_BAG) && + VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_LENGTH) && + VerifyField<uint8_t>(verifier, VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO) && + verifier.EndTable(); + } + BoundsSensitiveFeaturesT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BoundsSensitiveFeaturesT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<BoundsSensitiveFeatures> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BoundsSensitiveFeaturesBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_enabled(bool enabled) { + fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_ENABLED, static_cast<uint8_t>(enabled), 0); + } + void add_num_tokens_before(int32_t num_tokens_before) { + fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_BEFORE, num_tokens_before, 0); + } + void add_num_tokens_inside_left(int32_t num_tokens_inside_left) { + fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_INSIDE_LEFT, num_tokens_inside_left, 0); + } + void add_num_tokens_inside_right(int32_t num_tokens_inside_right) { + fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_INSIDE_RIGHT, num_tokens_inside_right, 0); + } + void add_num_tokens_after(int32_t num_tokens_after) { + fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_AFTER, 
num_tokens_after, 0); + } + void add_include_inside_bag(bool include_inside_bag) { + fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_INCLUDE_INSIDE_BAG, static_cast<uint8_t>(include_inside_bag), 0); + } + void add_include_inside_length(bool include_inside_length) { + fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_INCLUDE_INSIDE_LENGTH, static_cast<uint8_t>(include_inside_length), 0); + } + void add_score_single_token_spans_as_zero(bool score_single_token_spans_as_zero) { + fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO, static_cast<uint8_t>(score_single_token_spans_as_zero), 0); + } + explicit BoundsSensitiveFeaturesBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BoundsSensitiveFeaturesBuilder &operator=(const BoundsSensitiveFeaturesBuilder &); + flatbuffers::Offset<BoundsSensitiveFeatures> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BoundsSensitiveFeatures>(end); + return o; + } +}; + +inline flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures( + flatbuffers::FlatBufferBuilder &_fbb, + bool enabled = false, + int32_t num_tokens_before = 0, + int32_t num_tokens_inside_left = 0, + int32_t num_tokens_inside_right = 0, + int32_t num_tokens_after = 0, + bool include_inside_bag = false, + bool include_inside_length = false, + bool score_single_token_spans_as_zero = false) { + BoundsSensitiveFeaturesBuilder builder_(_fbb); + builder_.add_num_tokens_after(num_tokens_after); + builder_.add_num_tokens_inside_right(num_tokens_inside_right); + builder_.add_num_tokens_inside_left(num_tokens_inside_left); + builder_.add_num_tokens_before(num_tokens_before); + builder_.add_score_single_token_spans_as_zero(score_single_token_spans_as_zero); + builder_.add_include_inside_length(include_inside_length); + builder_.add_include_inside_bag(include_inside_bag); + builder_.add_enabled(enabled); + return builder_.Finish(); +} + 
+flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AlternativeCollectionMapEntryT : public flatbuffers::NativeTable { + typedef AlternativeCollectionMapEntry TableType; + std::string key; + std::string value; + AlternativeCollectionMapEntryT() { + } +}; + +struct AlternativeCollectionMapEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AlternativeCollectionMapEntryT NativeTableType; + enum { + VT_KEY = 4, + VT_VALUE = 6 + }; + const flatbuffers::String *key() const { + return GetPointer<const flatbuffers::String *>(VT_KEY); + } + const flatbuffers::String *value() const { + return GetPointer<const flatbuffers::String *>(VT_VALUE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_KEY) && + verifier.Verify(key()) && + VerifyOffset(verifier, VT_VALUE) && + verifier.Verify(value()) && + verifier.EndTable(); + } + AlternativeCollectionMapEntryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AlternativeCollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<AlternativeCollectionMapEntry> Pack(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AlternativeCollectionMapEntryBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset<flatbuffers::String> key) { + fbb_.AddOffset(AlternativeCollectionMapEntry::VT_KEY, key); + } + void add_value(flatbuffers::Offset<flatbuffers::String> value) { + fbb_.AddOffset(AlternativeCollectionMapEntry::VT_VALUE, value); + } + explicit AlternativeCollectionMapEntryBuilder(flatbuffers::FlatBufferBuilder 
&_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AlternativeCollectionMapEntryBuilder &operator=(const AlternativeCollectionMapEntryBuilder &); + flatbuffers::Offset<AlternativeCollectionMapEntry> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<AlternativeCollectionMapEntry>(end); + return o; + } +}; + +inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> key = 0, + flatbuffers::Offset<flatbuffers::String> value = 0) { + AlternativeCollectionMapEntryBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_key(key); + return builder_.Finish(); +} + +inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntryDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *key = nullptr, + const char *value = nullptr) { + return libtextclassifier2::FeatureProcessorOptions_::CreateAlternativeCollectionMapEntry( + _fbb, + key ? _fbb.CreateString(key) : 0, + value ? 
_fbb.CreateString(value) : 0); +} + +flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +} // namespace FeatureProcessorOptions_ + +struct FeatureProcessorOptionsT : public flatbuffers::NativeTable { + typedef FeatureProcessorOptions TableType; + int32_t num_buckets; + int32_t embedding_size; + int32_t embedding_quantization_bits; + int32_t context_size; + int32_t max_selection_span; + std::vector<int32_t> chargram_orders; + int32_t max_word_length; + bool unicode_aware_features; + bool extract_case_feature; + bool extract_selection_mask_feature; + std::vector<std::string> regexp_feature; + bool remap_digits; + bool lowercase_tokens; + bool selection_reduced_output_space; + std::vector<std::string> collections; + int32_t default_collection; + bool only_use_line_with_click; + bool split_tokens_on_selection_boundaries; + std::vector<std::unique_ptr<TokenizationCodepointRangeT>> tokenization_codepoint_config; + libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method; + bool snap_label_span_boundaries_to_containing_tokens; + std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> supported_codepoint_ranges; + std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> internal_tokenizer_codepoint_ranges; + float min_supported_codepoint_ratio; + int32_t feature_version; + libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type; + bool icu_preserve_whitespace_tokens; + std::vector<int32_t> ignored_span_boundary_codepoints; + std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT> bounds_sensitive_features; + std::vector<std::string> allowed_chargrams; + bool tokenize_on_script_change; + FeatureProcessorOptionsT() + : 
num_buckets(-1), + embedding_size(-1), + embedding_quantization_bits(8), + context_size(-1), + max_selection_span(-1), + max_word_length(20), + unicode_aware_features(false), + extract_case_feature(false), + extract_selection_mask_feature(false), + remap_digits(false), + lowercase_tokens(false), + selection_reduced_output_space(true), + default_collection(-1), + only_use_line_with_click(false), + split_tokens_on_selection_boundaries(false), + center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD), + snap_label_span_boundaries_to_containing_tokens(false), + min_supported_codepoint_ratio(0.0f), + feature_version(0), + tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER), + icu_preserve_whitespace_tokens(false), + tokenize_on_script_change(false) { + } +}; + +struct FeatureProcessorOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FeatureProcessorOptionsT NativeTableType; + enum { + VT_NUM_BUCKETS = 4, + VT_EMBEDDING_SIZE = 6, + VT_EMBEDDING_QUANTIZATION_BITS = 8, + VT_CONTEXT_SIZE = 10, + VT_MAX_SELECTION_SPAN = 12, + VT_CHARGRAM_ORDERS = 14, + VT_MAX_WORD_LENGTH = 16, + VT_UNICODE_AWARE_FEATURES = 18, + VT_EXTRACT_CASE_FEATURE = 20, + VT_EXTRACT_SELECTION_MASK_FEATURE = 22, + VT_REGEXP_FEATURE = 24, + VT_REMAP_DIGITS = 26, + VT_LOWERCASE_TOKENS = 28, + VT_SELECTION_REDUCED_OUTPUT_SPACE = 30, + VT_COLLECTIONS = 32, + VT_DEFAULT_COLLECTION = 34, + VT_ONLY_USE_LINE_WITH_CLICK = 36, + VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES = 38, + VT_TOKENIZATION_CODEPOINT_CONFIG = 40, + VT_CENTER_TOKEN_SELECTION_METHOD = 42, + VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS = 44, + VT_SUPPORTED_CODEPOINT_RANGES = 46, + VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES = 48, + VT_MIN_SUPPORTED_CODEPOINT_RATIO = 50, + VT_FEATURE_VERSION = 52, + VT_TOKENIZATION_TYPE = 54, + VT_ICU_PRESERVE_WHITESPACE_TOKENS = 56, + VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS 
= 58, + VT_BOUNDS_SENSITIVE_FEATURES = 60, + VT_ALLOWED_CHARGRAMS = 62, + VT_TOKENIZE_ON_SCRIPT_CHANGE = 64 + }; + int32_t num_buckets() const { + return GetField<int32_t>(VT_NUM_BUCKETS, -1); + } + int32_t embedding_size() const { + return GetField<int32_t>(VT_EMBEDDING_SIZE, -1); + } + int32_t embedding_quantization_bits() const { + return GetField<int32_t>(VT_EMBEDDING_QUANTIZATION_BITS, 8); + } + int32_t context_size() const { + return GetField<int32_t>(VT_CONTEXT_SIZE, -1); + } + int32_t max_selection_span() const { + return GetField<int32_t>(VT_MAX_SELECTION_SPAN, -1); + } + const flatbuffers::Vector<int32_t> *chargram_orders() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_CHARGRAM_ORDERS); + } + int32_t max_word_length() const { + return GetField<int32_t>(VT_MAX_WORD_LENGTH, 20); + } + bool unicode_aware_features() const { + return GetField<uint8_t>(VT_UNICODE_AWARE_FEATURES, 0) != 0; + } + bool extract_case_feature() const { + return GetField<uint8_t>(VT_EXTRACT_CASE_FEATURE, 0) != 0; + } + bool extract_selection_mask_feature() const { + return GetField<uint8_t>(VT_EXTRACT_SELECTION_MASK_FEATURE, 0) != 0; + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXP_FEATURE); + } + bool remap_digits() const { + return GetField<uint8_t>(VT_REMAP_DIGITS, 0) != 0; + } + bool lowercase_tokens() const { + return GetField<uint8_t>(VT_LOWERCASE_TOKENS, 0) != 0; + } + bool selection_reduced_output_space() const { + return GetField<uint8_t>(VT_SELECTION_REDUCED_OUTPUT_SPACE, 1) != 0; + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *collections() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_COLLECTIONS); + } + int32_t default_collection() const { + return GetField<int32_t>(VT_DEFAULT_COLLECTION, -1); + } + bool 
only_use_line_with_click() const { + return GetField<uint8_t>(VT_ONLY_USE_LINE_WITH_CLICK, 0) != 0; + } + bool split_tokens_on_selection_boundaries() const { + return GetField<uint8_t>(VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, 0) != 0; + } + const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *>(VT_TOKENIZATION_CODEPOINT_CONFIG); + } + libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method() const { + return static_cast<libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod>(GetField<int32_t>(VT_CENTER_TOKEN_SELECTION_METHOD, 0)); + } + bool snap_label_span_boundaries_to_containing_tokens() const { + return GetField<uint8_t>(VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, 0) != 0; + } + const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_SUPPORTED_CODEPOINT_RANGES); + } + const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES); + } + float min_supported_codepoint_ratio() const { + return GetField<float>(VT_MIN_SUPPORTED_CODEPOINT_RATIO, 0.0f); + } + int32_t feature_version() const { + return GetField<int32_t>(VT_FEATURE_VERSION, 0); + } + libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type() const { + return 
static_cast<libtextclassifier2::FeatureProcessorOptions_::TokenizationType>(GetField<int32_t>(VT_TOKENIZATION_TYPE, 1)); + } + bool icu_preserve_whitespace_tokens() const { + return GetField<uint8_t>(VT_ICU_PRESERVE_WHITESPACE_TOKENS, 0) != 0; + } + const flatbuffers::Vector<int32_t> *ignored_span_boundary_codepoints() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS); + } + const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *bounds_sensitive_features() const { + return GetPointer<const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *>(VT_BOUNDS_SENSITIVE_FEATURES); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_ALLOWED_CHARGRAMS); + } + bool tokenize_on_script_change() const { + return GetField<uint8_t>(VT_TOKENIZE_ON_SCRIPT_CHANGE, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_NUM_BUCKETS) && + VerifyField<int32_t>(verifier, VT_EMBEDDING_SIZE) && + VerifyField<int32_t>(verifier, VT_EMBEDDING_QUANTIZATION_BITS) && + VerifyField<int32_t>(verifier, VT_CONTEXT_SIZE) && + VerifyField<int32_t>(verifier, VT_MAX_SELECTION_SPAN) && + VerifyOffset(verifier, VT_CHARGRAM_ORDERS) && + verifier.Verify(chargram_orders()) && + VerifyField<int32_t>(verifier, VT_MAX_WORD_LENGTH) && + VerifyField<uint8_t>(verifier, VT_UNICODE_AWARE_FEATURES) && + VerifyField<uint8_t>(verifier, VT_EXTRACT_CASE_FEATURE) && + VerifyField<uint8_t>(verifier, VT_EXTRACT_SELECTION_MASK_FEATURE) && + VerifyOffset(verifier, VT_REGEXP_FEATURE) && + verifier.Verify(regexp_feature()) && + verifier.VerifyVectorOfStrings(regexp_feature()) && + VerifyField<uint8_t>(verifier, VT_REMAP_DIGITS) && + VerifyField<uint8_t>(verifier, VT_LOWERCASE_TOKENS) && + 
VerifyField<uint8_t>(verifier, VT_SELECTION_REDUCED_OUTPUT_SPACE) && + VerifyOffset(verifier, VT_COLLECTIONS) && + verifier.Verify(collections()) && + verifier.VerifyVectorOfStrings(collections()) && + VerifyField<int32_t>(verifier, VT_DEFAULT_COLLECTION) && + VerifyField<uint8_t>(verifier, VT_ONLY_USE_LINE_WITH_CLICK) && + VerifyField<uint8_t>(verifier, VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES) && + VerifyOffset(verifier, VT_TOKENIZATION_CODEPOINT_CONFIG) && + verifier.Verify(tokenization_codepoint_config()) && + verifier.VerifyVectorOfTables(tokenization_codepoint_config()) && + VerifyField<int32_t>(verifier, VT_CENTER_TOKEN_SELECTION_METHOD) && + VerifyField<uint8_t>(verifier, VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS) && + VerifyOffset(verifier, VT_SUPPORTED_CODEPOINT_RANGES) && + verifier.Verify(supported_codepoint_ranges()) && + verifier.VerifyVectorOfTables(supported_codepoint_ranges()) && + VerifyOffset(verifier, VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES) && + verifier.Verify(internal_tokenizer_codepoint_ranges()) && + verifier.VerifyVectorOfTables(internal_tokenizer_codepoint_ranges()) && + VerifyField<float>(verifier, VT_MIN_SUPPORTED_CODEPOINT_RATIO) && + VerifyField<int32_t>(verifier, VT_FEATURE_VERSION) && + VerifyField<int32_t>(verifier, VT_TOKENIZATION_TYPE) && + VerifyField<uint8_t>(verifier, VT_ICU_PRESERVE_WHITESPACE_TOKENS) && + VerifyOffset(verifier, VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS) && + verifier.Verify(ignored_span_boundary_codepoints()) && + VerifyOffset(verifier, VT_BOUNDS_SENSITIVE_FEATURES) && + verifier.VerifyTable(bounds_sensitive_features()) && + VerifyOffset(verifier, VT_ALLOWED_CHARGRAMS) && + verifier.Verify(allowed_chargrams()) && + verifier.VerifyVectorOfStrings(allowed_chargrams()) && + VerifyField<uint8_t>(verifier, VT_TOKENIZE_ON_SCRIPT_CHANGE) && + verifier.EndTable(); + } + FeatureProcessorOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FeatureProcessorOptionsT *_o, 
const flatbuffers::resolver_function_t *_resolver = nullptr) const;
  static flatbuffers::Offset<FeatureProcessorOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};

// Builder for the FeatureProcessorOptions table.
// Auto-generated by the FlatBuffers compiler (flatc) — do not edit by hand.
// Each add_* method writes one field into the table under construction; the
// third argument of AddElement is the field's schema default.
struct FeatureProcessorOptionsBuilder {
  flatbuffers::FlatBufferBuilder &fbb_;  // buffer the table is serialized into
  flatbuffers::uoffset_t start_;         // offset returned by StartTable()
  void add_num_buckets(int32_t num_buckets) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_NUM_BUCKETS, num_buckets, -1);
  }
  void add_embedding_size(int32_t embedding_size) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_SIZE, embedding_size, -1);
  }
  void add_embedding_quantization_bits(int32_t embedding_quantization_bits) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_QUANTIZATION_BITS, embedding_quantization_bits, 8);
  }
  void add_context_size(int32_t context_size) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CONTEXT_SIZE, context_size, -1);
  }
  void add_max_selection_span(int32_t max_selection_span) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_SELECTION_SPAN, max_selection_span, -1);
  }
  void add_chargram_orders(flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_CHARGRAM_ORDERS, chargram_orders);
  }
  void add_max_word_length(int32_t max_word_length) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_WORD_LENGTH, max_word_length, 20);
  }
  // Booleans are stored as uint8_t in the wire format.
  void add_unicode_aware_features(bool unicode_aware_features) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_UNICODE_AWARE_FEATURES, static_cast<uint8_t>(unicode_aware_features), 0);
  }
  void add_extract_case_feature(bool extract_case_feature) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_CASE_FEATURE, static_cast<uint8_t>(extract_case_feature), 0);
  }
  void add_extract_selection_mask_feature(bool extract_selection_mask_feature) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_SELECTION_MASK_FEATURE, static_cast<uint8_t>(extract_selection_mask_feature), 0);
  }
  void add_regexp_feature(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_REGEXP_FEATURE, regexp_feature);
  }
  void add_remap_digits(bool remap_digits) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_REMAP_DIGITS, static_cast<uint8_t>(remap_digits), 0);
  }
  void add_lowercase_tokens(bool lowercase_tokens) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_LOWERCASE_TOKENS, static_cast<uint8_t>(lowercase_tokens), 0);
  }
  // Note: schema default is 1 (true), unlike the other boolean fields.
  void add_selection_reduced_output_space(bool selection_reduced_output_space) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SELECTION_REDUCED_OUTPUT_SPACE, static_cast<uint8_t>(selection_reduced_output_space), 1);
  }
  void add_collections(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_COLLECTIONS, collections);
  }
  void add_default_collection(int32_t default_collection) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_DEFAULT_COLLECTION, default_collection, -1);
  }
  void add_only_use_line_with_click(bool only_use_line_with_click) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ONLY_USE_LINE_WITH_CLICK, static_cast<uint8_t>(only_use_line_with_click), 0);
  }
  void add_split_tokens_on_selection_boundaries(bool split_tokens_on_selection_boundaries) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, static_cast<uint8_t>(split_tokens_on_selection_boundaries), 0);
  }
  void add_tokenization_codepoint_config(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_TOKENIZATION_CODEPOINT_CONFIG, tokenization_codepoint_config);
  }
  // Enum fields are serialized as their underlying int32_t value.
  void add_center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CENTER_TOKEN_SELECTION_METHOD, static_cast<int32_t>(center_token_selection_method), 0);
  }
  void add_snap_label_span_boundaries_to_containing_tokens(bool snap_label_span_boundaries_to_containing_tokens) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, static_cast<uint8_t>(snap_label_span_boundaries_to_containing_tokens), 0);
  }
  void add_supported_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_SUPPORTED_CODEPOINT_RANGES, supported_codepoint_ranges);
  }
  void add_internal_tokenizer_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES, internal_tokenizer_codepoint_ranges);
  }
  void add_min_supported_codepoint_ratio(float min_supported_codepoint_ratio) {
    fbb_.AddElement<float>(FeatureProcessorOptions::VT_MIN_SUPPORTED_CODEPOINT_RATIO, min_supported_codepoint_ratio, 0.0f);
  }
  void add_feature_version(int32_t feature_version) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_FEATURE_VERSION, feature_version, 0);
  }
  // Schema default is 1 (TokenizationType_INTERNAL_TOKENIZER).
  void add_tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type) {
    fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_TOKENIZATION_TYPE, static_cast<int32_t>(tokenization_type), 1);
  }
  void add_icu_preserve_whitespace_tokens(bool icu_preserve_whitespace_tokens) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ICU_PRESERVE_WHITESPACE_TOKENS, static_cast<uint8_t>(icu_preserve_whitespace_tokens), 0);
  }
  void add_ignored_span_boundary_codepoints(flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS, ignored_span_boundary_codepoints);
  }
  void add_bounds_sensitive_features(flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_BOUNDS_SENSITIVE_FEATURES, bounds_sensitive_features);
  }
  void add_allowed_chargrams(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams) {
    fbb_.AddOffset(FeatureProcessorOptions::VT_ALLOWED_CHARGRAMS, allowed_chargrams);
  }
  void add_tokenize_on_script_change(bool tokenize_on_script_change) {
    fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_TOKENIZE_ON_SCRIPT_CHANGE, static_cast<uint8_t>(tokenize_on_script_change), 0);
  }
  explicit FeatureProcessorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
        : fbb_(_fbb) {
    start_ = fbb_.StartTable();
  }
  // Copy assignment intentionally declared but not defined (non-assignable).
  FeatureProcessorOptionsBuilder &operator=(const FeatureProcessorOptionsBuilder &);
  // Closes the table and returns its offset within the buffer.
  flatbuffers::Offset<FeatureProcessorOptions> Finish() {
    const auto end = fbb_.EndTable(start_);
    auto o = flatbuffers::Offset<FeatureProcessorOptions>(end);
    return o;
  }
};

// Convenience helper that drives the builder above: every field has a
// parameter whose default mirrors the schema default, and the add_* calls
// are emitted in the (generated) size-descending order.
inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(
    flatbuffers::FlatBufferBuilder &_fbb,
    int32_t num_buckets = -1,
    int32_t embedding_size = -1,
    int32_t embedding_quantization_bits = 8,
    int32_t context_size = -1,
    int32_t max_selection_span = -1,
    flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders = 0,
    int32_t max_word_length = 20,
    bool unicode_aware_features = false,
    bool extract_case_feature = false,
    bool extract_selection_mask_feature = false,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature = 0,
    bool remap_digits = false,
    bool lowercase_tokens = false,
    bool selection_reduced_output_space = true,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections = 0,
    int32_t default_collection = -1,
    bool only_use_line_with_click = false,
    bool split_tokens_on_selection_boundaries = false,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config = 0,
    libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
    bool snap_label_span_boundaries_to_containing_tokens = false,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges = 0,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges = 0,
    float min_supported_codepoint_ratio = 0.0f,
    int32_t feature_version = 0,
    libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER,
    bool icu_preserve_whitespace_tokens = false,
    flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints = 0,
    flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams = 0,
    bool tokenize_on_script_change = false) {
  FeatureProcessorOptionsBuilder builder_(_fbb);
  builder_.add_allowed_chargrams(allowed_chargrams);
  builder_.add_bounds_sensitive_features(bounds_sensitive_features);
  builder_.add_ignored_span_boundary_codepoints(ignored_span_boundary_codepoints);
  builder_.add_tokenization_type(tokenization_type);
  builder_.add_feature_version(feature_version);
  builder_.add_min_supported_codepoint_ratio(min_supported_codepoint_ratio);
  builder_.add_internal_tokenizer_codepoint_ranges(internal_tokenizer_codepoint_ranges);
  builder_.add_supported_codepoint_ranges(supported_codepoint_ranges);
  builder_.add_center_token_selection_method(center_token_selection_method);
  builder_.add_tokenization_codepoint_config(tokenization_codepoint_config);
  builder_.add_default_collection(default_collection);
  builder_.add_collections(collections);
  builder_.add_regexp_feature(regexp_feature);
  builder_.add_max_word_length(max_word_length);
  builder_.add_chargram_orders(chargram_orders);
  builder_.add_max_selection_span(max_selection_span);
  builder_.add_context_size(context_size);
  builder_.add_embedding_quantization_bits(embedding_quantization_bits);
  builder_.add_embedding_size(embedding_size);
  builder_.add_num_buckets(num_buckets);
  builder_.add_tokenize_on_script_change(tokenize_on_script_change);
  builder_.add_icu_preserve_whitespace_tokens(icu_preserve_whitespace_tokens);
  builder_.add_snap_label_span_boundaries_to_containing_tokens(snap_label_span_boundaries_to_containing_tokens);
  builder_.add_split_tokens_on_selection_boundaries(split_tokens_on_selection_boundaries);
  builder_.add_only_use_line_with_click(only_use_line_with_click);
  builder_.add_selection_reduced_output_space(selection_reduced_output_space);
  builder_.add_lowercase_tokens(lowercase_tokens);
  builder_.add_remap_digits(remap_digits);
  builder_.add_extract_selection_mask_feature(extract_selection_mask_feature);
  builder_.add_extract_case_feature(extract_case_feature);
  builder_.add_unicode_aware_features(unicode_aware_features);
  return builder_.Finish();
}

// "Direct" variant: accepts std::vector/std::string inputs and serializes
// them into the buffer before delegating to the offset-based Create above.
inline
flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptionsDirect(
    flatbuffers::FlatBufferBuilder &_fbb,
    int32_t num_buckets = -1,
    int32_t embedding_size = -1,
    int32_t embedding_quantization_bits = 8,
    int32_t context_size = -1,
    int32_t max_selection_span = -1,
    const std::vector<int32_t> *chargram_orders = nullptr,
    int32_t max_word_length = 20,
    bool unicode_aware_features = false,
    bool extract_case_feature = false,
    bool extract_selection_mask_feature = false,
    const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature = nullptr,
    bool remap_digits = false,
    bool lowercase_tokens = false,
    bool selection_reduced_output_space = true,
    const std::vector<flatbuffers::Offset<flatbuffers::String>> *collections = nullptr,
    int32_t default_collection = -1,
    bool only_use_line_with_click = false,
    bool split_tokens_on_selection_boundaries = false,
    const std::vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config = nullptr,
    libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
    bool snap_label_span_boundaries_to_containing_tokens = false,
    const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges = nullptr,
    const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges = nullptr,
    float min_supported_codepoint_ratio = 0.0f,
    int32_t feature_version = 0,
    libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER,
    bool icu_preserve_whitespace_tokens = false,
    const std::vector<int32_t> *ignored_span_boundary_codepoints = nullptr,
    flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
    const std::vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams = nullptr,
    bool tokenize_on_script_change = false) {
  // Vectors passed as nullptr are encoded as the null offset 0 (field absent).
  return libtextclassifier2::CreateFeatureProcessorOptions(
      _fbb,
      num_buckets,
      embedding_size,
      embedding_quantization_bits,
      context_size,
      max_selection_span,
      chargram_orders ? _fbb.CreateVector<int32_t>(*chargram_orders) : 0,
      max_word_length,
      unicode_aware_features,
      extract_case_feature,
      extract_selection_mask_feature,
      regexp_feature ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexp_feature) : 0,
      remap_digits,
      lowercase_tokens,
      selection_reduced_output_space,
      collections ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*collections) : 0,
      default_collection,
      only_use_line_with_click,
      split_tokens_on_selection_boundaries,
      tokenization_codepoint_config ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>>(*tokenization_codepoint_config) : 0,
      center_token_selection_method,
      snap_label_span_boundaries_to_containing_tokens,
      supported_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*supported_codepoint_ranges) : 0,
      internal_tokenizer_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*internal_tokenizer_codepoint_ranges) : 0,
      min_supported_codepoint_ratio,
      feature_version,
      tokenization_type,
      icu_preserve_whitespace_tokens,
      ignored_span_boundary_codepoints ? _fbb.CreateVector<int32_t>(*ignored_span_boundary_codepoints) : 0,
      bounds_sensitive_features,
      allowed_chargrams ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*allowed_chargrams) : 0,
      tokenize_on_script_change);
}

// Object-API overload (takes the native FeatureProcessorOptionsT); defined
// out of line in the generated .cc.
flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);

// Deep-copies this flatbuffer table into a heap-allocated native object.
// Caller owns the returned pointer.
inline CompressedBufferT *CompressedBuffer::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new CompressedBufferT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Copies each present field of the table into the native object _o.
inline void CompressedBuffer::UnPackTo(CompressedBufferT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = buffer(); if (_e) { _o->buffer.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffer[_i] = _e->Get(_i); } } };
  { auto _e = uncompressed_size(); _o->uncompressed_size = _e; };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<CompressedBuffer> CompressedBuffer::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateCompressedBuffer(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // _va packages the builder/object/rehasher for the CreateVector lambdas used
  // by table-valued vector fields (unused here, hence the (void) cast).
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CompressedBufferT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _buffer = _o->buffer.size() ?
_fbb.CreateVector(_o->buffer) : 0;
  auto _uncompressed_size = _o->uncompressed_size;
  return libtextclassifier2::CreateCompressedBuffer(
      _fbb,
      _buffer,
      _uncompressed_size);
}

// Deep-copies this flatbuffer table into a heap-allocated native object.
// Caller owns the returned pointer.
inline SelectionModelOptionsT *SelectionModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new SelectionModelOptionsT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Copies each scalar field of the table into the native object _o.
inline void SelectionModelOptions::UnPackTo(SelectionModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = strip_unpaired_brackets(); _o->strip_unpaired_brackets = _e; };
  { auto _e = symmetry_context_size(); _o->symmetry_context_size = _e; };
  { auto _e = batch_size(); _o->batch_size = _e; };
  { auto _e = always_classify_suggested_selection(); _o->always_classify_suggested_selection = _e; };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<SelectionModelOptions> SelectionModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateSelectionModelOptions(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // _va is only needed for table-valued vector fields; unused for this table.
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SelectionModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _strip_unpaired_brackets = _o->strip_unpaired_brackets;
  auto _symmetry_context_size = _o->symmetry_context_size;
  auto _batch_size = _o->batch_size;
  auto _always_classify_suggested_selection = _o->always_classify_suggested_selection;
  return libtextclassifier2::CreateSelectionModelOptions(
      _fbb,
      _strip_unpaired_brackets,
      _symmetry_context_size,
      _batch_size,
      _always_classify_suggested_selection);
}

inline
ClassificationModelOptionsT *ClassificationModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  // Deep-copies this flatbuffer table into a heap-allocated native object;
  // caller owns the result.
  auto _o = new ClassificationModelOptionsT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Copies each scalar field of the table into the native object _o.
inline void ClassificationModelOptions::UnPackTo(ClassificationModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = phone_min_num_digits(); _o->phone_min_num_digits = _e; };
  { auto _e = phone_max_num_digits(); _o->phone_max_num_digits = _e; };
  { auto _e = address_min_num_tokens(); _o->address_min_num_tokens = _e; };
  { auto _e = max_num_tokens(); _o->max_num_tokens = _e; };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<ClassificationModelOptions> ClassificationModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateClassificationModelOptions(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // _va is only needed for table-valued vector fields; unused for this table.
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ClassificationModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _phone_min_num_digits = _o->phone_min_num_digits;
  auto _phone_max_num_digits = _o->phone_max_num_digits;
  auto _address_min_num_tokens = _o->address_min_num_tokens;
  auto _max_num_tokens = _o->max_num_tokens;
  return libtextclassifier2::CreateClassificationModelOptions(
      _fbb,
      _phone_min_num_digits,
      _phone_max_num_digits,
      _address_min_num_tokens,
      _max_num_tokens);
}

namespace RegexModel_ {

// Deep-copies this flatbuffer table into a heap-allocated native object;
// caller owns the result.
inline PatternT *Pattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new PatternT();
  UnPackTo(_o, _resolver);
  return _o;
}
inline void Pattern::UnPackTo(PatternT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  // Copies each present field into the native object; string fields are
  // converted via str(), nested tables are recursively UnPack'ed.
  (void)_o;
  (void)_resolver;
  { auto _e = collection_name(); if (_e) _o->collection_name = _e->str(); };
  { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
  { auto _e = enabled_modes(); _o->enabled_modes = _e; };
  { auto _e = target_classification_score(); _o->target_classification_score = _e; };
  { auto _e = priority_score(); _o->priority_score = _e; };
  { auto _e = use_approximate_matching(); _o->use_approximate_matching = _e; };
  { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<libtextclassifier2::CompressedBufferT>(_e->UnPack(_resolver)); };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<Pattern> Pattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreatePattern(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // _va is only needed for table-valued vector fields; unused for this table.
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  // Empty strings are encoded as the null offset 0 (field absent).
  auto _collection_name = _o->collection_name.empty() ? 0 : _fbb.CreateString(_o->collection_name);
  auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
  auto _enabled_modes = _o->enabled_modes;
  auto _target_classification_score = _o->target_classification_score;
  auto _priority_score = _o->priority_score;
  auto _use_approximate_matching = _o->use_approximate_matching;
  auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0;
  return libtextclassifier2::RegexModel_::CreatePattern(
      _fbb,
      _collection_name,
      _pattern,
      _enabled_modes,
      _target_classification_score,
      _priority_score,
      _use_approximate_matching,
      _compressed_pattern);
}

}  // namespace RegexModel_

// Deep-copies this flatbuffer table into a heap-allocated native object;
// caller owns the result.
inline RegexModelT *RegexModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new RegexModelT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Recursively unpacks the vector of nested Pattern tables.
inline void RegexModel::UnPackTo(RegexModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<libtextclassifier2::RegexModel_::PatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<RegexModel> RegexModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateRegexModel(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<RegexModel> CreateRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // _va carries builder/object/rehasher into the per-element lambda below.
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreatePattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
  return libtextclassifier2::CreateRegexModel(
      _fbb,
      _patterns);
}

namespace DatetimeModelPattern_ {

// Deep-copies this flatbuffer table into a heap-allocated native object;
// caller owns the result.
inline RegexT *Regex::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new RegexT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Copies each present field into the native object; the int32 group vector
// is cast element-wise to the DatetimeGroupType enum.
inline void Regex::UnPackTo(RegexT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
  { auto _e = groups(); if (_e) { _o->groups.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->groups[_i] = (DatetimeGroupType)_e->Get(_i); } } };
  { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<libtextclassifier2::CompressedBufferT>(_e->UnPack(_resolver)); };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<Regex> Regex::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateRegex(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<Regex> CreateRegex(flatbuffers::FlatBufferBuilder &_fbb, const RegexT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
  // Enum vector is serialized through its int32_t underlying representation.
  auto _groups = _o->groups.size() ? _fbb.CreateVector((const int32_t*)_o->groups.data(), _o->groups.size()) : 0;
  auto _compressed_pattern = _o->compressed_pattern ?
CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0;
  return libtextclassifier2::DatetimeModelPattern_::CreateRegex(
      _fbb,
      _pattern,
      _groups,
      _compressed_pattern);
}

}  // namespace DatetimeModelPattern_

// Deep-copies this flatbuffer table into a heap-allocated native object;
// caller owns the result.
inline DatetimeModelPatternT *DatetimeModelPattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new DatetimeModelPatternT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Copies each present field into the native object; nested Regex tables are
// recursively UnPack'ed.
inline void DatetimeModelPattern::UnPackTo(DatetimeModelPatternT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = regexes(); if (_e) { _o->regexes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexes[_i] = std::unique_ptr<libtextclassifier2::DatetimeModelPattern_::RegexT>(_e->Get(_i)->UnPack(_resolver)); } } };
  { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } };
  { auto _e = target_classification_score(); _o->target_classification_score = _e; };
  { auto _e = priority_score(); _o->priority_score = _e; };
  { auto _e = enabled_modes(); _o->enabled_modes = _e; };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<DatetimeModelPattern> DatetimeModelPattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateDatetimeModelPattern(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // _va carries builder/object/rehasher into the per-element lambda below.
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelPatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _regexes = _o->regexes.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> (_o->regexes.size(), [](size_t i, _VectorArgs *__va) { return CreateRegex(*__va->__fbb, __va->__o->regexes[i].get(), __va->__rehasher); }, &_va ) : 0;
  auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0;
  auto _target_classification_score = _o->target_classification_score;
  auto _priority_score = _o->priority_score;
  auto _enabled_modes = _o->enabled_modes;
  return libtextclassifier2::CreateDatetimeModelPattern(
      _fbb,
      _regexes,
      _locales,
      _target_classification_score,
      _priority_score,
      _enabled_modes);
}

// Deep-copies this flatbuffer table into a heap-allocated native object;
// caller owns the result.
inline DatetimeModelExtractorT *DatetimeModelExtractor::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new DatetimeModelExtractorT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Copies each present field into the native object.
inline void DatetimeModelExtractor::UnPackTo(DatetimeModelExtractorT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = extractor(); _o->extractor = _e; };
  { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
  { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } };
  { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<CompressedBufferT>(_e->UnPack(_resolver)); };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<DatetimeModelExtractor> DatetimeModelExtractor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateDatetimeModelExtractor(_fbb, _o, _rehasher);
}

inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelExtractorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _extractor = _o->extractor;
  auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
  auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0;
  auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0;
  return libtextclassifier2::CreateDatetimeModelExtractor(
      _fbb,
      _extractor,
      _pattern,
      _locales,
      _compressed_pattern);
}

// Deep-copies this flatbuffer table into a heap-allocated native object;
// caller owns the result.
inline DatetimeModelT *DatetimeModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new DatetimeModelT();
  UnPackTo(_o, _resolver);
  return _o;
}

// Copies each present field into the native object; nested pattern/extractor
// tables are recursively UnPack'ed, locale strings converted via str().
inline void DatetimeModel::UnPackTo(DatetimeModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i)->str(); } } };
  { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<DatetimeModelPatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
  { auto _e = extractors(); if (_e) { _o->extractors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->extractors[_i] = std::unique_ptr<DatetimeModelExtractorT>(_e->Get(_i)->UnPack(_resolver)); } } };
  { auto _e = use_extractors_for_locating(); _o->use_extractors_for_locating = _e; };
  { auto _e = default_locales(); if (_e) { _o->default_locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->default_locales[_i] = _e->Get(_i); } } };
}

// Serializes a native object back into a buffer (delegates to Create*).
inline flatbuffers::Offset<DatetimeModel> DatetimeModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t
*_rehasher) { + return CreateDatetimeModel(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _locales = _o->locales.size() ? _fbb.CreateVectorOfStrings(_o->locales) : 0; + auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelPattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _extractors = _o->extractors.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>> (_o->extractors.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelExtractor(*__va->__fbb, __va->__o->extractors[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _use_extractors_for_locating = _o->use_extractors_for_locating; + auto _default_locales = _o->default_locales.size() ? 
_fbb.CreateVector(_o->default_locales) : 0; + return libtextclassifier2::CreateDatetimeModel( + _fbb, + _locales, + _patterns, + _extractors, + _use_extractors_for_locating, + _default_locales); +} + +namespace DatetimeModelLibrary_ { + +inline ItemT *Item::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ItemT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Item::UnPackTo(ItemT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = key(); if (_e) _o->key = _e->str(); }; + { auto _e = value(); if (_e) _o->value = std::unique_ptr<libtextclassifier2::DatetimeModelT>(_e->UnPack(_resolver)); }; +} + +inline flatbuffers::Offset<Item> Item::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ItemT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateItem(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<Item> CreateItem(flatbuffers::FlatBufferBuilder &_fbb, const ItemT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ItemT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key); + auto _value = _o->value ? 
CreateDatetimeModel(_fbb, _o->value.get(), _rehasher) : 0; + return libtextclassifier2::DatetimeModelLibrary_::CreateItem( + _fbb, + _key, + _value); +} + +} // namespace DatetimeModelLibrary_ + +inline DatetimeModelLibraryT *DatetimeModelLibrary::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new DatetimeModelLibraryT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void DatetimeModelLibrary::UnPackTo(DatetimeModelLibraryT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = models(); if (_e) { _o->models.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->models[_i] = std::unique_ptr<libtextclassifier2::DatetimeModelLibrary_::ItemT>(_e->Get(_i)->UnPack(_resolver)); } } }; +} + +inline flatbuffers::Offset<DatetimeModelLibrary> DatetimeModelLibrary::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateDatetimeModelLibrary(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelLibraryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _models = _o->models.size() ? 
_fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> (_o->models.size(), [](size_t i, _VectorArgs *__va) { return CreateItem(*__va->__fbb, __va->__o->models[i].get(), __va->__rehasher); }, &_va ) : 0; + return libtextclassifier2::CreateDatetimeModelLibrary( + _fbb, + _models); +} + +inline ModelTriggeringOptionsT *ModelTriggeringOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ModelTriggeringOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ModelTriggeringOptions::UnPackTo(ModelTriggeringOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = min_annotate_confidence(); _o->min_annotate_confidence = _e; }; + { auto _e = enabled_modes(); _o->enabled_modes = _e; }; +} + +inline flatbuffers::Offset<ModelTriggeringOptions> ModelTriggeringOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateModelTriggeringOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelTriggeringOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _min_annotate_confidence = _o->min_annotate_confidence; + auto _enabled_modes = _o->enabled_modes; + return libtextclassifier2::CreateModelTriggeringOptions( + _fbb, + _min_annotate_confidence, + _enabled_modes); +} + +inline OutputOptionsT *OutputOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OutputOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void OutputOptions::UnPackTo(OutputOptionsT 
*_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = filtered_collections_annotation(); if (_e) { _o->filtered_collections_annotation.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_annotation[_i] = _e->Get(_i)->str(); } } }; + { auto _e = filtered_collections_classification(); if (_e) { _o->filtered_collections_classification.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_classification[_i] = _e->Get(_i)->str(); } } }; + { auto _e = filtered_collections_selection(); if (_e) { _o->filtered_collections_selection.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_selection[_i] = _e->Get(_i)->str(); } } }; +} + +inline flatbuffers::Offset<OutputOptions> OutputOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOutputOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<OutputOptions> CreateOutputOptions(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OutputOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _filtered_collections_annotation = _o->filtered_collections_annotation.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_annotation) : 0; + auto _filtered_collections_classification = _o->filtered_collections_classification.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_classification) : 0; + auto _filtered_collections_selection = _o->filtered_collections_selection.size() ? 
_fbb.CreateVectorOfStrings(_o->filtered_collections_selection) : 0; + return libtextclassifier2::CreateOutputOptions( + _fbb, + _filtered_collections_annotation, + _filtered_collections_classification, + _filtered_collections_selection); +} + +inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ModelT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = locales(); if (_e) _o->locales = _e->str(); }; + { auto _e = version(); _o->version = _e; }; + { auto _e = name(); if (_e) _o->name = _e->str(); }; + { auto _e = selection_feature_options(); if (_e) _o->selection_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); }; + { auto _e = classification_feature_options(); if (_e) _o->classification_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); }; + { auto _e = selection_model(); if (_e) { _o->selection_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->selection_model[_i] = _e->Get(_i); } } }; + { auto _e = classification_model(); if (_e) { _o->classification_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->classification_model[_i] = _e->Get(_i); } } }; + { auto _e = embedding_model(); if (_e) { _o->embedding_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_model[_i] = _e->Get(_i); } } }; + { auto _e = selection_options(); if (_e) _o->selection_options = std::unique_ptr<SelectionModelOptionsT>(_e->UnPack(_resolver)); }; + { auto _e = classification_options(); if (_e) _o->classification_options = std::unique_ptr<ClassificationModelOptionsT>(_e->UnPack(_resolver)); }; + { auto _e = regex_model(); if (_e) _o->regex_model = std::unique_ptr<RegexModelT>(_e->UnPack(_resolver)); }; + { auto 
_e = datetime_model(); if (_e) _o->datetime_model = std::unique_ptr<DatetimeModelT>(_e->UnPack(_resolver)); }; + { auto _e = triggering_options(); if (_e) _o->triggering_options = std::unique_ptr<ModelTriggeringOptionsT>(_e->UnPack(_resolver)); }; + { auto _e = enabled_modes(); _o->enabled_modes = _e; }; + { auto _e = snap_whitespace_selections(); _o->snap_whitespace_selections = _e; }; + { auto _e = output_options(); if (_e) _o->output_options = std::unique_ptr<OutputOptionsT>(_e->UnPack(_resolver)); }; +} + +inline flatbuffers::Offset<Model> Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateModel(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _locales = _o->locales.empty() ? 0 : _fbb.CreateString(_o->locales); + auto _version = _o->version; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _selection_feature_options = _o->selection_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->selection_feature_options.get(), _rehasher) : 0; + auto _classification_feature_options = _o->classification_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->classification_feature_options.get(), _rehasher) : 0; + auto _selection_model = _o->selection_model.size() ? _fbb.CreateVector(_o->selection_model) : 0; + auto _classification_model = _o->classification_model.size() ? _fbb.CreateVector(_o->classification_model) : 0; + auto _embedding_model = _o->embedding_model.size() ? _fbb.CreateVector(_o->embedding_model) : 0; + auto _selection_options = _o->selection_options ? 
CreateSelectionModelOptions(_fbb, _o->selection_options.get(), _rehasher) : 0; + auto _classification_options = _o->classification_options ? CreateClassificationModelOptions(_fbb, _o->classification_options.get(), _rehasher) : 0; + auto _regex_model = _o->regex_model ? CreateRegexModel(_fbb, _o->regex_model.get(), _rehasher) : 0; + auto _datetime_model = _o->datetime_model ? CreateDatetimeModel(_fbb, _o->datetime_model.get(), _rehasher) : 0; + auto _triggering_options = _o->triggering_options ? CreateModelTriggeringOptions(_fbb, _o->triggering_options.get(), _rehasher) : 0; + auto _enabled_modes = _o->enabled_modes; + auto _snap_whitespace_selections = _o->snap_whitespace_selections; + auto _output_options = _o->output_options ? CreateOutputOptions(_fbb, _o->output_options.get(), _rehasher) : 0; + return libtextclassifier2::CreateModel( + _fbb, + _locales, + _version, + _name, + _selection_feature_options, + _classification_feature_options, + _selection_model, + _classification_model, + _embedding_model, + _selection_options, + _classification_options, + _regex_model, + _datetime_model, + _triggering_options, + _enabled_modes, + _snap_whitespace_selections, + _output_options); +} + +inline TokenizationCodepointRangeT *TokenizationCodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TokenizationCodepointRangeT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TokenizationCodepointRange::UnPackTo(TokenizationCodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = start(); _o->start = _e; }; + { auto _e = end(); _o->end = _e; }; + { auto _e = role(); _o->role = _e; }; + { auto _e = script_id(); _o->script_id = _e; }; +} + +inline flatbuffers::Offset<TokenizationCodepointRange> TokenizationCodepointRange::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return 
CreateTokenizationCodepointRange(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TokenizationCodepointRangeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _start = _o->start; + auto _end = _o->end; + auto _role = _o->role; + auto _script_id = _o->script_id; + return libtextclassifier2::CreateTokenizationCodepointRange( + _fbb, + _start, + _end, + _role, + _script_id); +} + +namespace FeatureProcessorOptions_ { + +inline CodepointRangeT *CodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new CodepointRangeT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void CodepointRange::UnPackTo(CodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = start(); _o->start = _e; }; + { auto _e = end(); _o->end = _e; }; +} + +inline flatbuffers::Offset<CodepointRange> CodepointRange::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateCodepointRange(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<CodepointRange> CreateCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CodepointRangeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _start = _o->start; + auto _end = _o->end; + return libtextclassifier2::FeatureProcessorOptions_::CreateCodepointRange( + _fbb, + _start, + _end); +} + 
+inline BoundsSensitiveFeaturesT *BoundsSensitiveFeatures::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BoundsSensitiveFeaturesT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void BoundsSensitiveFeatures::UnPackTo(BoundsSensitiveFeaturesT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = enabled(); _o->enabled = _e; }; + { auto _e = num_tokens_before(); _o->num_tokens_before = _e; }; + { auto _e = num_tokens_inside_left(); _o->num_tokens_inside_left = _e; }; + { auto _e = num_tokens_inside_right(); _o->num_tokens_inside_right = _e; }; + { auto _e = num_tokens_after(); _o->num_tokens_after = _e; }; + { auto _e = include_inside_bag(); _o->include_inside_bag = _e; }; + { auto _e = include_inside_length(); _o->include_inside_length = _e; }; + { auto _e = score_single_token_spans_as_zero(); _o->score_single_token_spans_as_zero = _e; }; +} + +inline flatbuffers::Offset<BoundsSensitiveFeatures> BoundsSensitiveFeatures::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBoundsSensitiveFeatures(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BoundsSensitiveFeaturesT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _enabled = _o->enabled; + auto _num_tokens_before = _o->num_tokens_before; + auto _num_tokens_inside_left = _o->num_tokens_inside_left; + auto _num_tokens_inside_right = _o->num_tokens_inside_right; + auto _num_tokens_after = _o->num_tokens_after; + auto _include_inside_bag = _o->include_inside_bag; + auto 
_include_inside_length = _o->include_inside_length; + auto _score_single_token_spans_as_zero = _o->score_single_token_spans_as_zero; + return libtextclassifier2::FeatureProcessorOptions_::CreateBoundsSensitiveFeatures( + _fbb, + _enabled, + _num_tokens_before, + _num_tokens_inside_left, + _num_tokens_inside_right, + _num_tokens_after, + _include_inside_bag, + _include_inside_length, + _score_single_token_spans_as_zero); +} + +inline AlternativeCollectionMapEntryT *AlternativeCollectionMapEntry::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new AlternativeCollectionMapEntryT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void AlternativeCollectionMapEntry::UnPackTo(AlternativeCollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = key(); if (_e) _o->key = _e->str(); }; + { auto _e = value(); if (_e) _o->value = _e->str(); }; +} + +inline flatbuffers::Offset<AlternativeCollectionMapEntry> AlternativeCollectionMapEntry::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateAlternativeCollectionMapEntry(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AlternativeCollectionMapEntryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key); + auto _value = _o->value.empty() ? 
0 : _fbb.CreateString(_o->value); + return libtextclassifier2::FeatureProcessorOptions_::CreateAlternativeCollectionMapEntry( + _fbb, + _key, + _value); +} + +} // namespace FeatureProcessorOptions_ + +inline FeatureProcessorOptionsT *FeatureProcessorOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new FeatureProcessorOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void FeatureProcessorOptions::UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num_buckets(); _o->num_buckets = _e; }; + { auto _e = embedding_size(); _o->embedding_size = _e; }; + { auto _e = embedding_quantization_bits(); _o->embedding_quantization_bits = _e; }; + { auto _e = context_size(); _o->context_size = _e; }; + { auto _e = max_selection_span(); _o->max_selection_span = _e; }; + { auto _e = chargram_orders(); if (_e) { _o->chargram_orders.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->chargram_orders[_i] = _e->Get(_i); } } }; + { auto _e = max_word_length(); _o->max_word_length = _e; }; + { auto _e = unicode_aware_features(); _o->unicode_aware_features = _e; }; + { auto _e = extract_case_feature(); _o->extract_case_feature = _e; }; + { auto _e = extract_selection_mask_feature(); _o->extract_selection_mask_feature = _e; }; + { auto _e = regexp_feature(); if (_e) { _o->regexp_feature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexp_feature[_i] = _e->Get(_i)->str(); } } }; + { auto _e = remap_digits(); _o->remap_digits = _e; }; + { auto _e = lowercase_tokens(); _o->lowercase_tokens = _e; }; + { auto _e = selection_reduced_output_space(); _o->selection_reduced_output_space = _e; }; + { auto _e = collections(); if (_e) { _o->collections.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->collections[_i] = _e->Get(_i)->str(); } } }; + { auto 
_e = default_collection(); _o->default_collection = _e; }; + { auto _e = only_use_line_with_click(); _o->only_use_line_with_click = _e; }; + { auto _e = split_tokens_on_selection_boundaries(); _o->split_tokens_on_selection_boundaries = _e; }; + { auto _e = tokenization_codepoint_config(); if (_e) { _o->tokenization_codepoint_config.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tokenization_codepoint_config[_i] = std::unique_ptr<TokenizationCodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = center_token_selection_method(); _o->center_token_selection_method = _e; }; + { auto _e = snap_label_span_boundaries_to_containing_tokens(); _o->snap_label_span_boundaries_to_containing_tokens = _e; }; + { auto _e = supported_codepoint_ranges(); if (_e) { _o->supported_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->supported_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = internal_tokenizer_codepoint_ranges(); if (_e) { _o->internal_tokenizer_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->internal_tokenizer_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = min_supported_codepoint_ratio(); _o->min_supported_codepoint_ratio = _e; }; + { auto _e = feature_version(); _o->feature_version = _e; }; + { auto _e = tokenization_type(); _o->tokenization_type = _e; }; + { auto _e = icu_preserve_whitespace_tokens(); _o->icu_preserve_whitespace_tokens = _e; }; + { auto _e = ignored_span_boundary_codepoints(); if (_e) { _o->ignored_span_boundary_codepoints.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->ignored_span_boundary_codepoints[_i] = _e->Get(_i); } } }; + { auto _e = 
bounds_sensitive_features(); if (_e) _o->bounds_sensitive_features = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT>(_e->UnPack(_resolver)); }; + { auto _e = allowed_chargrams(); if (_e) { _o->allowed_chargrams.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->allowed_chargrams[_i] = _e->Get(_i)->str(); } } }; + { auto _e = tokenize_on_script_change(); _o->tokenize_on_script_change = _e; }; +} + +inline flatbuffers::Offset<FeatureProcessorOptions> FeatureProcessorOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateFeatureProcessorOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FeatureProcessorOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num_buckets = _o->num_buckets; + auto _embedding_size = _o->embedding_size; + auto _embedding_quantization_bits = _o->embedding_quantization_bits; + auto _context_size = _o->context_size; + auto _max_selection_span = _o->max_selection_span; + auto _chargram_orders = _o->chargram_orders.size() ? _fbb.CreateVector(_o->chargram_orders) : 0; + auto _max_word_length = _o->max_word_length; + auto _unicode_aware_features = _o->unicode_aware_features; + auto _extract_case_feature = _o->extract_case_feature; + auto _extract_selection_mask_feature = _o->extract_selection_mask_feature; + auto _regexp_feature = _o->regexp_feature.size() ? 
_fbb.CreateVectorOfStrings(_o->regexp_feature) : 0; + auto _remap_digits = _o->remap_digits; + auto _lowercase_tokens = _o->lowercase_tokens; + auto _selection_reduced_output_space = _o->selection_reduced_output_space; + auto _collections = _o->collections.size() ? _fbb.CreateVectorOfStrings(_o->collections) : 0; + auto _default_collection = _o->default_collection; + auto _only_use_line_with_click = _o->only_use_line_with_click; + auto _split_tokens_on_selection_boundaries = _o->split_tokens_on_selection_boundaries; + auto _tokenization_codepoint_config = _o->tokenization_codepoint_config.size() ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>> (_o->tokenization_codepoint_config.size(), [](size_t i, _VectorArgs *__va) { return CreateTokenizationCodepointRange(*__va->__fbb, __va->__o->tokenization_codepoint_config[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _center_token_selection_method = _o->center_token_selection_method; + auto _snap_label_span_boundaries_to_containing_tokens = _o->snap_label_span_boundaries_to_containing_tokens; + auto _supported_codepoint_ranges = _o->supported_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->supported_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->supported_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _internal_tokenizer_codepoint_ranges = _o->internal_tokenizer_codepoint_ranges.size() ? 
_fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->internal_tokenizer_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->internal_tokenizer_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _min_supported_codepoint_ratio = _o->min_supported_codepoint_ratio; + auto _feature_version = _o->feature_version; + auto _tokenization_type = _o->tokenization_type; + auto _icu_preserve_whitespace_tokens = _o->icu_preserve_whitespace_tokens; + auto _ignored_span_boundary_codepoints = _o->ignored_span_boundary_codepoints.size() ? _fbb.CreateVector(_o->ignored_span_boundary_codepoints) : 0; + auto _bounds_sensitive_features = _o->bounds_sensitive_features ? CreateBoundsSensitiveFeatures(_fbb, _o->bounds_sensitive_features.get(), _rehasher) : 0; + auto _allowed_chargrams = _o->allowed_chargrams.size() ? _fbb.CreateVectorOfStrings(_o->allowed_chargrams) : 0; + auto _tokenize_on_script_change = _o->tokenize_on_script_change; + return libtextclassifier2::CreateFeatureProcessorOptions( + _fbb, + _num_buckets, + _embedding_size, + _embedding_quantization_bits, + _context_size, + _max_selection_span, + _chargram_orders, + _max_word_length, + _unicode_aware_features, + _extract_case_feature, + _extract_selection_mask_feature, + _regexp_feature, + _remap_digits, + _lowercase_tokens, + _selection_reduced_output_space, + _collections, + _default_collection, + _only_use_line_with_click, + _split_tokens_on_selection_boundaries, + _tokenization_codepoint_config, + _center_token_selection_method, + _snap_label_span_boundaries_to_containing_tokens, + _supported_codepoint_ranges, + _internal_tokenizer_codepoint_ranges, + _min_supported_codepoint_ratio, + _feature_version, + _tokenization_type, + _icu_preserve_whitespace_tokens, + _ignored_span_boundary_codepoints, + _bounds_sensitive_features, + _allowed_chargrams, + _tokenize_on_script_change); +} + +inline const 
libtextclassifier2::Model *GetModel(const void *buf) { + return flatbuffers::GetRoot<libtextclassifier2::Model>(buf); +} + +inline const char *ModelIdentifier() { + return "TC2 "; +} + +inline bool ModelBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier( + buf, ModelIdentifier()); +} + +inline bool VerifyModelBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer<libtextclassifier2::Model>(ModelIdentifier()); +} + +inline void FinishModelBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset<libtextclassifier2::Model> root) { + fbb.Finish(root, ModelIdentifier()); +} + +inline std::unique_ptr<ModelT> UnPackModel( + const void *buf, + const flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr<ModelT>(GetModel(buf)->UnPack(res)); +} + +} // namespace libtextclassifier2 + +#endif // FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_ diff --git a/models/textclassifier.ar.model b/models/textclassifier.ar.model Binary files differnew file mode 100644 index 0000000..2342daa --- /dev/null +++ b/models/textclassifier.ar.model diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model Binary files differnew file mode 100644 index 0000000..a40f940 --- /dev/null +++ b/models/textclassifier.en.model diff --git a/models/textclassifier.es.model b/models/textclassifier.es.model Binary files differnew file mode 100644 index 0000000..7de4e5d --- /dev/null +++ b/models/textclassifier.es.model diff --git a/models/textclassifier.fr.model b/models/textclassifier.fr.model Binary files differnew file mode 100644 index 0000000..1072041 --- /dev/null +++ b/models/textclassifier.fr.model diff --git a/models/textclassifier.it.model b/models/textclassifier.it.model Binary files differnew file mode 100644 index 0000000..5bc98ae --- /dev/null +++ b/models/textclassifier.it.model diff --git a/models/textclassifier.ja.model b/models/textclassifier.ja.model Binary files differnew file mode 
100644 index 0000000..9f60b8a --- /dev/null +++ b/models/textclassifier.ja.model diff --git a/models/textclassifier.ko.model b/models/textclassifier.ko.model Binary files differnew file mode 100644 index 0000000..451df45 --- /dev/null +++ b/models/textclassifier.ko.model diff --git a/models/textclassifier.langid.model b/models/textclassifier.langid.model Binary files differdeleted file mode 100644 index 6b68223..0000000 --- a/models/textclassifier.langid.model +++ /dev/null diff --git a/models/textclassifier.nl.model b/models/textclassifier.nl.model Binary files differnew file mode 100644 index 0000000..07ea076 --- /dev/null +++ b/models/textclassifier.nl.model diff --git a/models/textclassifier.pl.model b/models/textclassifier.pl.model Binary files differnew file mode 100644 index 0000000..6cf62a5 --- /dev/null +++ b/models/textclassifier.pl.model diff --git a/models/textclassifier.pt.model b/models/textclassifier.pt.model Binary files differnew file mode 100644 index 0000000..a745d58 --- /dev/null +++ b/models/textclassifier.pt.model diff --git a/models/textclassifier.ru.model b/models/textclassifier.ru.model Binary files differnew file mode 100644 index 0000000..aa97ebc --- /dev/null +++ b/models/textclassifier.ru.model diff --git a/models/textclassifier.smartselection.ar.model b/models/textclassifier.smartselection.ar.model Binary files differdeleted file mode 100644 index f22fe0f..0000000 --- a/models/textclassifier.smartselection.ar.model +++ /dev/null diff --git a/models/textclassifier.smartselection.de.model b/models/textclassifier.smartselection.de.model Binary files differdeleted file mode 100644 index 5eb3181..0000000 --- a/models/textclassifier.smartselection.de.model +++ /dev/null diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model Binary files differdeleted file mode 100644 index 7af0897..0000000 --- a/models/textclassifier.smartselection.en.model +++ /dev/null diff --git 
a/models/textclassifier.smartselection.es.model b/models/textclassifier.smartselection.es.model Binary files differdeleted file mode 100644 index 9ea6af9..0000000 --- a/models/textclassifier.smartselection.es.model +++ /dev/null diff --git a/models/textclassifier.smartselection.fr.model b/models/textclassifier.smartselection.fr.model Binary files differdeleted file mode 100644 index 3ff5416..0000000 --- a/models/textclassifier.smartselection.fr.model +++ /dev/null diff --git a/models/textclassifier.smartselection.it.model b/models/textclassifier.smartselection.it.model Binary files differdeleted file mode 100644 index 377fff5..0000000 --- a/models/textclassifier.smartselection.it.model +++ /dev/null diff --git a/models/textclassifier.smartselection.ja.model b/models/textclassifier.smartselection.ja.model Binary files differdeleted file mode 100644 index 53fce93..0000000 --- a/models/textclassifier.smartselection.ja.model +++ /dev/null diff --git a/models/textclassifier.smartselection.ko.model b/models/textclassifier.smartselection.ko.model Binary files differdeleted file mode 100644 index 6bcac15..0000000 --- a/models/textclassifier.smartselection.ko.model +++ /dev/null diff --git a/models/textclassifier.smartselection.nl.model b/models/textclassifier.smartselection.nl.model Binary files differdeleted file mode 100644 index c80dff6..0000000 --- a/models/textclassifier.smartselection.nl.model +++ /dev/null diff --git a/models/textclassifier.smartselection.pl.model b/models/textclassifier.smartselection.pl.model Binary files differdeleted file mode 100644 index 3379c63..0000000 --- a/models/textclassifier.smartselection.pl.model +++ /dev/null diff --git a/models/textclassifier.smartselection.pt.model b/models/textclassifier.smartselection.pt.model Binary files differdeleted file mode 100644 index 4378c8f..0000000 --- a/models/textclassifier.smartselection.pt.model +++ /dev/null diff --git a/models/textclassifier.smartselection.ru.model 
b/models/textclassifier.smartselection.ru.model Binary files differdeleted file mode 100644 index 0763b33..0000000 --- a/models/textclassifier.smartselection.ru.model +++ /dev/null diff --git a/models/textclassifier.smartselection.th.model b/models/textclassifier.smartselection.th.model Binary files differdeleted file mode 100644 index 521fea0..0000000 --- a/models/textclassifier.smartselection.th.model +++ /dev/null diff --git a/models/textclassifier.smartselection.tr.model b/models/textclassifier.smartselection.tr.model Binary files differdeleted file mode 100644 index 0177175..0000000 --- a/models/textclassifier.smartselection.tr.model +++ /dev/null diff --git a/models/textclassifier.smartselection.zh-Hant.model b/models/textclassifier.smartselection.zh-Hant.model Binary files differdeleted file mode 100644 index ec03c26..0000000 --- a/models/textclassifier.smartselection.zh-Hant.model +++ /dev/null diff --git a/models/textclassifier.smartselection.zh.model b/models/textclassifier.smartselection.zh.model Binary files differdeleted file mode 100644 index acc6142..0000000 --- a/models/textclassifier.smartselection.zh.model +++ /dev/null diff --git a/models/textclassifier.th.model b/models/textclassifier.th.model Binary files differnew file mode 100644 index 0000000..37339b7 --- /dev/null +++ b/models/textclassifier.th.model diff --git a/models/textclassifier.tr.model b/models/textclassifier.tr.model Binary files differnew file mode 100644 index 0000000..2405d9e --- /dev/null +++ b/models/textclassifier.tr.model diff --git a/models/textclassifier.universal.model b/models/textclassifier.universal.model Binary files differnew file mode 100644 index 0000000..5c4220f --- /dev/null +++ b/models/textclassifier.universal.model diff --git a/models/textclassifier.zh-Hant.model b/models/textclassifier.zh-Hant.model Binary files differnew file mode 100644 index 0000000..32edfe4 --- /dev/null +++ b/models/textclassifier.zh-Hant.model diff --git a/models/textclassifier.zh.model 
b/models/textclassifier.zh.model Binary files differnew file mode 100644 index 0000000..eb1ff61 --- /dev/null +++ b/models/textclassifier.zh.model diff --git a/models/update.sh b/models/update.sh new file mode 100755 index 0000000..8b60d2f --- /dev/null +++ b/models/update.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Updates the set of model with the most recent ones. + +set -e + +BASE_URL=https://www.gstatic.com/android/text_classifier/p/live + +cd "$(dirname "$0")" + +for f in $(wget -O- "$BASE_URL/FILELIST"); do + wget "$BASE_URL/$f" -O "$f" +done diff --git a/quantization.cc b/quantization.cc new file mode 100644 index 0000000..1a34565 --- /dev/null +++ b/quantization.cc @@ -0,0 +1,92 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "quantization.h" + +#include "util/base/logging.h" + +namespace libtextclassifier2 { +namespace { +float DequantizeValue(int num_sparse_features, int quantization_bias, + float multiplier, int value) { + return 1.0 / num_sparse_features * (value - quantization_bias) * multiplier; +} + +void DequantizeAdd8bit(const float* scales, const uint8* embeddings, + int bytes_per_embedding, const int num_sparse_features, + const int bucket_id, float* dest, int dest_size) { + static const int kQuantizationBias8bit = 128; + const float multiplier = scales[bucket_id]; + for (int k = 0; k < dest_size; ++k) { + dest[k] += + DequantizeValue(num_sparse_features, kQuantizationBias8bit, multiplier, + embeddings[bucket_id * bytes_per_embedding + k]); + } +} + +void DequantizeAddNBit(const float* scales, const uint8* embeddings, + int bytes_per_embedding, int num_sparse_features, + int quantization_bits, int bucket_id, float* dest, + int dest_size) { + const int quantization_bias = 1 << (quantization_bits - 1); + const float multiplier = scales[bucket_id]; + for (int i = 0; i < dest_size; ++i) { + const int bit_offset = i * quantization_bits; + const int read16_offset = bit_offset / 8; + + uint16 data = embeddings[bucket_id * bytes_per_embedding + read16_offset]; + // If we are not at the end of the embedding row, we can read 2-byte uint16, + // but if we are, we need to only read uint8. 
+ if (read16_offset < bytes_per_embedding - 1) { + data |= embeddings[bucket_id * bytes_per_embedding + read16_offset + 1] + << 8; + } + int value = (data >> (bit_offset % 8)) & ((1 << quantization_bits) - 1); + dest[i] += DequantizeValue(num_sparse_features, quantization_bias, + multiplier, value); + } +} +} // namespace + +bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits, + int output_embedding_size) { + if (bytes_per_embedding * 8 / quantization_bits < output_embedding_size) { + return false; + } + + return true; +} + +bool DequantizeAdd(const float* scales, const uint8* embeddings, + int bytes_per_embedding, int num_sparse_features, + int quantization_bits, int bucket_id, float* dest, + int dest_size) { + if (quantization_bits == 8) { + DequantizeAdd8bit(scales, embeddings, bytes_per_embedding, + num_sparse_features, bucket_id, dest, dest_size); + } else if (quantization_bits != 8) { + DequantizeAddNBit(scales, embeddings, bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, dest, + dest_size); + } else { + TC_LOG(ERROR) << "Unsupported quantization_bits: " << quantization_bits; + return false; + } + + return true; +} + +} // namespace libtextclassifier2 diff --git a/quantization.h b/quantization.h new file mode 100644 index 0000000..c486640 --- /dev/null +++ b/quantization.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LIBTEXTCLASSIFIER_QUANTIZATION_H_ +#define LIBTEXTCLASSIFIER_QUANTIZATION_H_ + +#include "util/base/integral_types.h" + +namespace libtextclassifier2 { + +// Returns true if the quantization parameters are valid. +bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits, + int output_embedding_size); + +// Dequantizes embeddings (quantized to 1 to 8 bits) into the floats they +// represent. The algorithm proceeds by reading 2-byte words from the embedding +// storage to handle well the cases when the quantized value crosses the byte- +// boundary. +bool DequantizeAdd(const float* scales, const uint8* embeddings, + int bytes_per_embedding, int num_sparse_features, + int quantization_bits, int bucket_id, float* dest, + int dest_size); + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_QUANTIZATION_H_ diff --git a/quantization_test.cc b/quantization_test.cc new file mode 100644 index 0000000..088daaf --- /dev/null +++ b/quantization_test.cc @@ -0,0 +1,163 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "quantization.h" + +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::ElementsAreArray; +using testing::FloatEq; +using testing::Matcher; + +namespace libtextclassifier2 { +namespace { + +Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) { + std::vector<Matcher<float>> matchers; + for (const float value : values) { + matchers.push_back(FloatEq(value)); + } + return ElementsAreArray(matchers); +} + +TEST(QuantizationTest, DequantizeAdd8bit) { + std::vector<float> scales{{0.1, 9.0, -7.0}}; + std::vector<uint8> embeddings{{/*0: */ 0x00, 0xFF, 0x09, 0x00, + /*1: */ 0xFF, 0x09, 0x00, 0xFF, + /*2: */ 0x09, 0x00, 0xFF, 0x09}}; + + const int quantization_bits = 8; + const int bytes_per_embedding = 4; + const int num_sparse_features = 7; + { + const int bucket_id = 0; + std::vector<float> dest(4, 0.0); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, + dest.data(), dest.size()); + + EXPECT_THAT(dest, + ElementsAreFloat(std::vector<float>{ + // clang-format off + {1.0 / 7 * 0.1 * (0x00 - 128), + 1.0 / 7 * 0.1 * (0xFF - 128), + 1.0 / 7 * 0.1 * (0x09 - 128), + 1.0 / 7 * 0.1 * (0x00 - 128)} + // clang-format on + })); + } + + { + const int bucket_id = 1; + std::vector<float> dest(4, 0.0); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, + dest.data(), dest.size()); + + EXPECT_THAT(dest, + ElementsAreFloat(std::vector<float>{ + // clang-format off + {1.0 / 7 * 9.0 * (0xFF - 128), + 1.0 / 7 * 9.0 * (0x09 - 128), + 1.0 / 7 * 9.0 * (0x00 - 128), + 1.0 / 7 * 9.0 * (0xFF - 128)} + // clang-format on + })); + } +} + +TEST(QuantizationTest, DequantizeAdd1bitZeros) { + const int bytes_per_embedding = 4; + const int num_buckets = 3; + const int num_sparse_features = 7; + const int quantization_bits = 1; + const int bucket_id = 1; + + std::vector<float> 
scales(num_buckets); + std::vector<uint8> embeddings(bytes_per_embedding * num_buckets); + std::fill(scales.begin(), scales.end(), 1); + std::fill(embeddings.begin(), embeddings.end(), 0); + + std::vector<float> dest(32); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, dest.data(), + dest.size()); + + std::vector<float> expected(32); + std::fill(expected.begin(), expected.end(), + 1.0 / num_sparse_features * (0 - 1)); + EXPECT_THAT(dest, ElementsAreFloat(expected)); +} + +TEST(QuantizationTest, DequantizeAdd1bitOnes) { + const int bytes_per_embedding = 4; + const int num_buckets = 3; + const int num_sparse_features = 7; + const int quantization_bits = 1; + const int bucket_id = 1; + + std::vector<float> scales(num_buckets, 1.0); + std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF); + + std::vector<float> dest(32); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, dest.data(), + dest.size()); + std::vector<float> expected(32); + std::fill(expected.begin(), expected.end(), + 1.0 / num_sparse_features * (1 - 1)); + EXPECT_THAT(dest, ElementsAreFloat(expected)); +} + +TEST(QuantizationTest, DequantizeAdd3bit) { + const int bytes_per_embedding = 4; + const int num_buckets = 3; + const int num_sparse_features = 7; + const int quantization_bits = 3; + const int bucket_id = 1; + + std::vector<float> scales(num_buckets, 1.0); + scales[1] = 9.0; + std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0); + // For bucket_id=1, the embedding has values 0..9 for indices 0..9: + embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1; + embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3); + embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1; + + std::vector<float> dest(10); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, 
dest.data(), + dest.size()); + + std::vector<float> expected; + expected.push_back(1.0 / num_sparse_features * (1 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (2 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (3 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (4 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (5 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (6 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (7 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]); + EXPECT_THAT(dest, ElementsAreFloat(expected)); +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/smartselect/cached-features.cc b/smartselect/cached-features.cc deleted file mode 100644 index c249db9..0000000 --- a/smartselect/cached-features.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "smartselect/cached-features.h" -#include "util/base/logging.h" - -namespace libtextclassifier { - -void CachedFeatures::Extract( - const std::vector<std::vector<int>>& sparse_features, - const std::vector<std::vector<float>>& dense_features, - const std::function<bool(const std::vector<int>&, const std::vector<float>&, - float*)>& feature_vector_fn) { - features_.resize(feature_vector_size_ * tokens_.size()); - for (int i = 0; i < tokens_.size(); ++i) { - feature_vector_fn(sparse_features[i], dense_features[i], - features_.data() + i * feature_vector_size_); - } -} - -bool CachedFeatures::Get(int click_pos, VectorSpan<float>* features, - VectorSpan<Token>* output_tokens) { - const int token_start = click_pos - context_size_; - const int token_end = click_pos + context_size_ + 1; - if (token_start < 0 || token_end > tokens_.size()) { - TC_LOG(ERROR) << "Tokens out of range: " << token_start << " " << token_end; - return false; - } - - *features = - VectorSpan<float>(features_.begin() + token_start * feature_vector_size_, - features_.begin() + token_end * feature_vector_size_); - *output_tokens = VectorSpan<Token>(tokens_.begin() + token_start, - tokens_.begin() + token_end); - if (remap_v0_feature_vector_) { - RemapV0FeatureVector(features); - } - - return true; -} - -void CachedFeatures::RemapV0FeatureVector(VectorSpan<float>* features) { - if (!remap_v0_feature_vector_) { - return; - } - - auto it = features->begin(); - int num_suffix_features = - feature_vector_size_ - remap_v0_chargram_embedding_size_; - int num_tokens = context_size_ * 2 + 1; - for (int t = 0; t < num_tokens; ++t) { - for (int i = 0; i < remap_v0_chargram_embedding_size_; ++i) { - v0_feature_storage_[t * remap_v0_chargram_embedding_size_ + i] = *it; - ++it; - } - // Rest of the features are the dense features that come to the end. 
- for (int i = 0; i < num_suffix_features; ++i) { - // clang-format off - v0_feature_storage_[num_tokens * remap_v0_chargram_embedding_size_ - + t * num_suffix_features - + i] = *it; - // clang-format on - ++it; - } - } - *features = VectorSpan<float>(v0_feature_storage_); -} - -} // namespace libtextclassifier diff --git a/smartselect/cached-features.h b/smartselect/cached-features.h deleted file mode 100644 index 990233c..0000000 --- a/smartselect/cached-features.h +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ - -#include <memory> -#include <vector> - -#include "common/vector-span.h" -#include "smartselect/types.h" - -namespace libtextclassifier { - -// Holds state for extracting features across multiple calls and reusing them. -// Assumes that features for each Token are independent. -class CachedFeatures { - public: - // Extracts the features for the given sequence of tokens. - // - context_size: Specifies how many tokens to the left, and how many - // tokens to the right spans the context. - // - sparse_features, dense_features: Extracted features for each token. - // - feature_vector_fn: Writes features for given Token to the specified - // storage. 
- // NOTE: The function can assume that the underlying - // storage is initialized to all zeros. - // - feature_vector_size: Size of a feature vector for one Token. - CachedFeatures(VectorSpan<Token> tokens, int context_size, - const std::vector<std::vector<int>>& sparse_features, - const std::vector<std::vector<float>>& dense_features, - const std::function<bool(const std::vector<int>&, - const std::vector<float>&, float*)>& - feature_vector_fn, - int feature_vector_size) - : tokens_(tokens), - context_size_(context_size), - feature_vector_size_(feature_vector_size), - remap_v0_feature_vector_(false), - remap_v0_chargram_embedding_size_(-1) { - Extract(sparse_features, dense_features, feature_vector_fn); - } - - // Gets a VectorSpan with the features for given click position. - bool Get(int click_pos, VectorSpan<float>* features, - VectorSpan<Token>* output_tokens); - - // Turns on a compatibility mode, which re-maps the extracted features to the - // v0 feature format (where the dense features were at the end). - // WARNING: Internally v0_feature_storage_ is used as a backing buffer for - // VectorSpan<float>, so the output of Extract is valid only until the next - // call or destruction of the current CachedFeatures object. - // TODO(zilka): Remove when we'll have retrained models. - void SetV0FeatureMode(int chargram_embedding_size) { - remap_v0_feature_vector_ = true; - remap_v0_chargram_embedding_size_ = chargram_embedding_size; - v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1)); - } - - protected: - // Extracts features for all tokens and stores them for later retrieval. - void Extract(const std::vector<std::vector<int>>& sparse_features, - const std::vector<std::vector<float>>& dense_features, - const std::function<bool(const std::vector<int>&, - const std::vector<float>&, float*)>& - feature_vector_fn); - - // Remaps extracted features to V0 feature format. 
The mapping is using - // the v0_feature_storage_ as the backing storage for the mapped features. - // For each token the features consist of: - // - chargram embeddings - // - dense features - // They are concatenated together as [chargram embeddings; dense features] - // for each token independently. - // The V0 features require that the chargram embeddings for tokens are - // concatenated first together, and at the end, the dense features for the - // tokens are concatenated to it. - void RemapV0FeatureVector(VectorSpan<float>* features); - - private: - const VectorSpan<Token> tokens_; - const int context_size_; - const int feature_vector_size_; - bool remap_v0_feature_vector_; - int remap_v0_chargram_embedding_size_; - - std::vector<float> features_; - std::vector<float> v0_feature_storage_; -}; - -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ diff --git a/smartselect/cached-features_test.cc b/smartselect/cached-features_test.cc deleted file mode 100644 index b456816..0000000 --- a/smartselect/cached-features_test.cc +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "smartselect/cached-features.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace { - -class TestingCachedFeatures : public CachedFeatures { - public: - using CachedFeatures::CachedFeatures; - using CachedFeatures::RemapV0FeatureVector; -}; - -TEST(CachedFeaturesTest, Simple) { - std::vector<Token> tokens; - tokens.push_back(Token()); - tokens.push_back(Token()); - tokens.push_back(Token("Hello", 0, 1)); - tokens.push_back(Token("World", 1, 2)); - tokens.push_back(Token("today!", 2, 3)); - tokens.push_back(Token()); - tokens.push_back(Token()); - - std::vector<std::vector<int>> sparse_features(tokens.size()); - for (int i = 0; i < sparse_features.size(); ++i) { - sparse_features[i].push_back(i); - } - std::vector<std::vector<float>> dense_features(tokens.size()); - for (int i = 0; i < dense_features.size(); ++i) { - dense_features[i].push_back(-i); - } - - TestingCachedFeatures feature_extractor( - tokens, /*context_size=*/2, sparse_features, dense_features, - [](const std::vector<int>& sparse_features, - const std::vector<float>& dense_features, float* features) { - features[0] = sparse_features[0]; - features[1] = sparse_features[0]; - features[2] = dense_features[0]; - features[3] = dense_features[0]; - features[4] = 123; - return true; - }, - 5); - - VectorSpan<float> features; - VectorSpan<Token> output_tokens; - EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens)); - for (int i = 0; i < 5; i++) { - EXPECT_EQ(features[i * 5 + 0], i) << "Feature " << i; - EXPECT_EQ(features[i * 5 + 1], i) << "Feature " << i; - EXPECT_EQ(features[i * 5 + 2], -i) << "Feature " << i; - EXPECT_EQ(features[i * 5 + 3], -i) << "Feature " << i; - EXPECT_EQ(features[i * 5 + 4], 123) << "Feature " << i; - } -} - -TEST(CachedFeaturesTest, InvalidInput) { - std::vector<Token> tokens; - tokens.push_back(Token()); - tokens.push_back(Token()); - tokens.push_back(Token("Hello", 0, 1)); - 
tokens.push_back(Token("World", 1, 2)); - tokens.push_back(Token("today!", 2, 3)); - tokens.push_back(Token()); - tokens.push_back(Token()); - - std::vector<std::vector<int>> sparse_features(tokens.size()); - std::vector<std::vector<float>> dense_features(tokens.size()); - - TestingCachedFeatures feature_extractor( - tokens, /*context_size=*/2, sparse_features, dense_features, - [](const std::vector<int>& sparse_features, - const std::vector<float>& dense_features, - float* features) { return true; }, - /*feature_vector_size=*/5); - - VectorSpan<float> features; - VectorSpan<Token> output_tokens; - EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens)); - EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens)); - EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens)); - EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens)); - EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens)); - EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens)); - EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens)); -} - -TEST(CachedFeaturesTest, RemapV0FeatureVector) { - std::vector<Token> tokens; - tokens.push_back(Token()); - tokens.push_back(Token()); - tokens.push_back(Token("Hello", 0, 1)); - tokens.push_back(Token("World", 1, 2)); - tokens.push_back(Token("today!", 2, 3)); - tokens.push_back(Token()); - tokens.push_back(Token()); - - std::vector<std::vector<int>> sparse_features(tokens.size()); - std::vector<std::vector<float>> dense_features(tokens.size()); - - TestingCachedFeatures feature_extractor( - tokens, /*context_size=*/2, sparse_features, dense_features, - [](const std::vector<int>& sparse_features, - const std::vector<float>& dense_features, - float* features) { return true; }, - /*feature_vector_size=*/5); - - std::vector<float> features_orig(5 * 5); - for (int i = 0; i < features_orig.size(); i++) { - features_orig[i] = i; - } - VectorSpan<float> features; - - 
feature_extractor.SetV0FeatureMode(0); - features = VectorSpan<float>(features_orig); - feature_extractor.RemapV0FeatureVector(&features); - EXPECT_EQ( - std::vector<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}), - std::vector<float>(features.begin(), features.end())); - - feature_extractor.SetV0FeatureMode(2); - features = VectorSpan<float>(features_orig); - feature_extractor.RemapV0FeatureVector(&features); - EXPECT_EQ(std::vector<float>({0, 1, 5, 6, 10, 11, 15, 16, 20, 21, 2, 3, 4, - 7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}), - std::vector<float>(features.begin(), features.end())); -} - -} // namespace -} // namespace libtextclassifier diff --git a/smartselect/model-params.cc b/smartselect/model-params.cc deleted file mode 100644 index 65c4f93..0000000 --- a/smartselect/model-params.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "smartselect/model-params.h" - -#include "common/memory_image/memory-image-reader.h" - -namespace libtextclassifier { - -using nlp_core::EmbeddingNetworkProto; -using nlp_core::MemoryImageReader; - -ModelParams* ModelParamsBuilder( - const void* start, uint64 num_bytes, - std::shared_ptr<EmbeddingParams> external_embedding_params) { - MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes); - - ModelOptions model_options; - auto model_options_extension_id = model_options_in_embedding_network_proto; - if (reader.trimmed_proto().HasExtension(model_options_extension_id)) { - model_options = - reader.trimmed_proto().GetExtension(model_options_extension_id); - } - - FeatureProcessorOptions feature_processor_options; - auto feature_processor_extension_id = - feature_processor_options_in_embedding_network_proto; - if (reader.trimmed_proto().HasExtension(feature_processor_extension_id)) { - feature_processor_options = - reader.trimmed_proto().GetExtension(feature_processor_extension_id); - - // If no tokenization codepoint config is present, tokenize on space. - // TODO(zilka): Remove the default config. - if (feature_processor_options.tokenization_codepoint_config_size() == 0) { - TokenizationCodepointRange* config; - // New line character. - config = feature_processor_options.add_tokenization_codepoint_config(); - config->set_start(10); - config->set_end(11); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - - // Space character. 
- config = feature_processor_options.add_tokenization_codepoint_config(); - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - } - } else { - return nullptr; - } - - SelectionModelOptions selection_options; - auto selection_options_extension_id = - selection_model_options_in_embedding_network_proto; - if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) { - selection_options = - reader.trimmed_proto().GetExtension(selection_options_extension_id); - - // For backward compatibility with the current models. - if (!feature_processor_options.ignored_span_boundary_codepoints_size()) { - *feature_processor_options.mutable_ignored_span_boundary_codepoints() = - selection_options.deprecated_punctuation_to_strip(); - } - } else { - selection_options.set_enforce_symmetry(true); - selection_options.set_symmetry_context_size( - feature_processor_options.context_size() * 2); - } - - SharingModelOptions sharing_options; - auto sharing_options_extension_id = - sharing_model_options_in_embedding_network_proto; - if (reader.trimmed_proto().HasExtension(sharing_options_extension_id)) { - sharing_options = - reader.trimmed_proto().GetExtension(sharing_options_extension_id); - } else { - // Default values when SharingModelOptions is not present. 
- sharing_options.set_always_accept_url_hint(true); - sharing_options.set_always_accept_email_hint(true); - } - - if (!model_options.use_shared_embeddings()) { - std::shared_ptr<EmbeddingParams> embedding_params(new EmbeddingParams( - start, num_bytes, feature_processor_options.context_size())); - return new ModelParams(start, num_bytes, embedding_params, - selection_options, sharing_options, - feature_processor_options); - } else { - return new ModelParams( - start, num_bytes, std::move(external_embedding_params), - selection_options, sharing_options, feature_processor_options); - } -} - -} // namespace libtextclassifier diff --git a/smartselect/model-params.h b/smartselect/model-params.h deleted file mode 100644 index a0d39e6..0000000 --- a/smartselect/model-params.h +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Model parameter loading. 
- -#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_ - -#include "common/embedding-network.h" -#include "common/memory_image/embedding-network-params-from-image.h" -#include "smartselect/text-classification-model.pb.h" - -namespace libtextclassifier { - -class EmbeddingParams : public nlp_core::EmbeddingNetworkParamsFromImage { - public: - EmbeddingParams(const void* start, uint64 num_bytes, int context_size) - : EmbeddingNetworkParamsFromImage(start, num_bytes), - context_size_(context_size) {} - - int embeddings_size() const override { return context_size_ * 2 + 1; } - - int embedding_num_features_size() const override { - return context_size_ * 2 + 1; - } - - int embedding_num_features(int i) const override { return 1; } - - int embeddings_num_rows(int i) const override { - return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0); - }; - - int embeddings_num_cols(int i) const override { - return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0); - }; - - const void* embeddings_weights(int i) const override { - return EmbeddingNetworkParamsFromImage::embeddings_weights(0); - }; - - nlp_core::QuantizationType embeddings_quant_type(int i) const override { - return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0); - } - - const nlp_core::float16* embeddings_quant_scales(int i) const override { - return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0); - } - - private: - int context_size_; -}; - -// Loads and holds the parameters of the inference network. -// -// This class overrides a couple of methods of EmbeddingNetworkParamsFromImage -// because we only have one embedding matrix for all positions of context, -// whereas the original class would have a separate one for each. 
-class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage { - public: - const FeatureProcessorOptions& GetFeatureProcessorOptions() const { - return feature_processor_options_; - } - - const SelectionModelOptions& GetSelectionModelOptions() const { - return selection_options_; - } - - const SharingModelOptions& GetSharingModelOptions() const { - return sharing_options_; - } - - std::shared_ptr<EmbeddingParams> GetEmbeddingParams() const { - return embedding_params_; - } - - protected: - int embeddings_size() const override { - return embedding_params_->embeddings_size(); - } - - int embedding_num_features_size() const override { - return embedding_params_->embedding_num_features_size(); - } - - int embedding_num_features(int i) const override { - return embedding_params_->embedding_num_features(i); - } - - int embeddings_num_rows(int i) const override { - return embedding_params_->embeddings_num_rows(i); - }; - - int embeddings_num_cols(int i) const override { - return embedding_params_->embeddings_num_cols(i); - }; - - const void* embeddings_weights(int i) const override { - return embedding_params_->embeddings_weights(i); - }; - - nlp_core::QuantizationType embeddings_quant_type(int i) const override { - return embedding_params_->embeddings_quant_type(i); - } - - const nlp_core::float16* embeddings_quant_scales(int i) const override { - return embedding_params_->embeddings_quant_scales(i); - } - - private: - friend ModelParams* ModelParamsBuilder( - const void* start, uint64 num_bytes, - std::shared_ptr<EmbeddingParams> external_embedding_params); - - ModelParams(const void* start, uint64 num_bytes, - std::shared_ptr<EmbeddingParams> embedding_params, - const SelectionModelOptions& selection_options, - const SharingModelOptions& sharing_options, - const FeatureProcessorOptions& feature_processor_options) - : EmbeddingNetworkParamsFromImage(start, num_bytes), - selection_options_(selection_options), - sharing_options_(sharing_options), - 
feature_processor_options_(feature_processor_options), - context_size_(feature_processor_options_.context_size()), - embedding_params_(std::move(embedding_params)) {} - - SelectionModelOptions selection_options_; - SharingModelOptions sharing_options_; - FeatureProcessorOptions feature_processor_options_; - int context_size_; - std::shared_ptr<EmbeddingParams> embedding_params_; -}; - -ModelParams* ModelParamsBuilder( - const void* start, uint64 num_bytes, - std::shared_ptr<EmbeddingParams> external_embedding_params); - -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_ diff --git a/smartselect/model-parser.cc b/smartselect/model-parser.cc deleted file mode 100644 index 0cf05e3..0000000 --- a/smartselect/model-parser.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "smartselect/model-parser.h" -#include "util/base/endian.h" - -namespace libtextclassifier { -namespace { - -// Small helper class for parsing the merged model format. -// The merged model consists of interleaved <int32 data_size, char* data> -// segments. 
-class MergedModelParser { - public: - MergedModelParser(const void* addr, const int size) - : addr_(reinterpret_cast<const char*>(addr)), size_(size), pos_(addr_) {} - - bool ReadBytesAndAdvance(int num_bytes, const char** result) { - const char* read_addr = pos_; - if (Advance(num_bytes)) { - *result = read_addr; - return true; - } else { - return false; - } - } - - bool ReadInt32AndAdvance(int* result) { - const char* read_addr = pos_; - if (Advance(sizeof(int))) { - *result = - LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(read_addr)); - return true; - } else { - return false; - } - } - - bool IsDone() { return pos_ == addr_ + size_; } - - private: - bool Advance(int num_bytes) { - pos_ += num_bytes; - return pos_ <= addr_ + size_; - } - - const char* addr_; - const int size_; - const char* pos_; -}; - -} // namespace - -bool ParseMergedModel(const void* addr, const int size, - const char** selection_model, int* selection_model_length, - const char** sharing_model, int* sharing_model_length) { - MergedModelParser parser(addr, size); - - if (!parser.ReadInt32AndAdvance(selection_model_length)) { - return false; - } - - if (!parser.ReadBytesAndAdvance(*selection_model_length, selection_model)) { - return false; - } - - if (!parser.ReadInt32AndAdvance(sharing_model_length)) { - return false; - } - - if (!parser.ReadBytesAndAdvance(*sharing_model_length, sharing_model)) { - return false; - } - - return parser.IsDone(); -} - -} // namespace libtextclassifier diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc deleted file mode 100644 index e7ae09c..0000000 --- a/smartselect/text-classification-model.cc +++ /dev/null @@ -1,741 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "smartselect/text-classification-model.h" - -#include <cctype> -#include <cmath> -#include <iterator> -#include <numeric> - -#include "common/embedding-network.h" -#include "common/feature-extractor.h" -#include "common/memory_image/embedding-network-params-from-image.h" -#include "common/memory_image/memory-image-reader.h" -#include "common/mmap.h" -#include "common/softmax.h" -#include "smartselect/model-parser.h" -#include "smartselect/text-classification-model.pb.h" -#include "util/base/logging.h" -#include "util/utf8/unicodetext.h" -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT -#include "unicode/regex.h" -#include "unicode/uchar.h" -#endif - -namespace libtextclassifier { - -using nlp_core::EmbeddingNetwork; -using nlp_core::EmbeddingNetworkProto; -using nlp_core::FeatureVector; -using nlp_core::MemoryImageReader; -using nlp_core::MmapFile; -using nlp_core::MmapHandle; -using nlp_core::ScopedMmap; - -namespace { - -int CountDigits(const std::string& str, CodepointSpan selection_indices) { - int count = 0; - int i = 0; - const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false); - for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) { - if (i >= selection_indices.first && i < selection_indices.second && - isdigit(*it)) { - ++count; - } - } - return count; -} - -std::string ExtractSelection(const std::string& context, - CodepointSpan selection_indices) { - const UnicodeText context_unicode = - UTF8ToUnicodeText(context, /*do_copy=*/false); - auto selection_begin = context_unicode.begin(); - 
std::advance(selection_begin, selection_indices.first); - auto selection_end = context_unicode.begin(); - std::advance(selection_end, selection_indices.second); - return UnicodeText::UTF8Substring(selection_begin, selection_end); -} - -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT -bool MatchesRegex(const icu::RegexPattern* regex, const std::string& context) { - const icu::UnicodeString unicode_context(context.c_str(), context.size(), - "utf-8"); - UErrorCode status = U_ZERO_ERROR; - std::unique_ptr<icu::RegexMatcher> matcher( - regex->matcher(unicode_context, status)); - return matcher->matches(0 /* start */, status); -} -#endif - -} // namespace - -TextClassificationModel::TextClassificationModel(const std::string& path) - : mmap_(new nlp_core::ScopedMmap(path)) { - InitFromMmap(); -} - -TextClassificationModel::TextClassificationModel(int fd) - : mmap_(new nlp_core::ScopedMmap(fd)) { - InitFromMmap(); -} - -TextClassificationModel::TextClassificationModel(int fd, int offset, int size) - : mmap_(new nlp_core::ScopedMmap(fd, offset, size)) { - InitFromMmap(); -} - -TextClassificationModel::TextClassificationModel(const void* addr, int size) { - initialized_ = LoadModels(addr, size); - if (!initialized_) { - TC_LOG(ERROR) << "Failed to load models"; - return; - } -} - -void TextClassificationModel::InitFromMmap() { - if (!mmap_->handle().ok()) { - return; - } - - initialized_ = - LoadModels(mmap_->handle().start(), mmap_->handle().num_bytes()); - if (!initialized_) { - TC_LOG(ERROR) << "Failed to load models"; - return; - } -} - -namespace { - -// Converts sparse features vector to nlp_core::FeatureVector. 
-void SparseFeaturesToFeatureVector( - const std::vector<int> sparse_features, - const nlp_core::NumericFeatureType& feature_type, - nlp_core::FeatureVector* result) { - for (int feature_id : sparse_features) { - const int64 feature_value = - nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size()) - .discrete_value; - result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type), - feature_value); - } -} - -// Returns a function that can be used for mapping sparse and dense features -// to a float feature vector. -// NOTE: The network object needs to be available at the time when the returned -// function object is used. -FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network, - int sparse_embedding_size) { - const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0); - return [&network, sparse_embedding_size, feature_type]( - const std::vector<int>& sparse_features, - const std::vector<float>& dense_features, float* embedding) { - nlp_core::FeatureVector feature_vector; - SparseFeaturesToFeatureVector(sparse_features, feature_type, - &feature_vector); - - if (network.GetEmbedding(feature_vector, 0, embedding)) { - for (int i = 0; i < dense_features.size(); i++) { - embedding[sparse_embedding_size + i] = dense_features[i]; - } - return true; - } else { - return false; - } - }; -} - -} // namespace - -void TextClassificationModel::InitializeSharingRegexPatterns( - const std::vector<SharingModelOptions::RegexPattern>& patterns) { -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - // Initialize pattern recognizers. 
- for (const auto& regex_pattern : patterns) { - UErrorCode status = U_ZERO_ERROR; - std::unique_ptr<icu::RegexPattern> compiled_pattern( - icu::RegexPattern::compile( - icu::UnicodeString(regex_pattern.pattern().c_str(), - regex_pattern.pattern().size(), "utf-8"), - 0 /* flags */, status)); - if (U_FAILURE(status)) { - TC_LOG(WARNING) << "Failed to load pattern" << regex_pattern.pattern(); - } else { - regex_patterns_.push_back( - {regex_pattern.collection_name(), std::move(compiled_pattern)}); - } - } -#else - if (!patterns.empty()) { - TC_LOG(WARNING) << "ICU not supported regexp matchers ignored."; - } -#endif -} - -bool TextClassificationModel::LoadModels(const void* addr, int size) { - const char *selection_model, *sharing_model; - int selection_model_length, sharing_model_length; - if (!ParseMergedModel(addr, size, &selection_model, &selection_model_length, - &sharing_model, &sharing_model_length)) { - TC_LOG(ERROR) << "Couldn't parse the model."; - return false; - } - - selection_params_.reset( - ModelParamsBuilder(selection_model, selection_model_length, nullptr)); - if (!selection_params_.get()) { - return false; - } - selection_options_ = selection_params_->GetSelectionModelOptions(); - selection_network_.reset(new EmbeddingNetwork(selection_params_.get())); - selection_feature_processor_.reset( - new FeatureProcessor(selection_params_->GetFeatureProcessorOptions())); - selection_feature_fn_ = CreateFeatureVectorFn( - *selection_network_, selection_network_->EmbeddingSize(0)); - - sharing_params_.reset( - ModelParamsBuilder(sharing_model, sharing_model_length, - selection_params_->GetEmbeddingParams())); - if (!sharing_params_.get()) { - return false; - } - sharing_options_ = selection_params_->GetSharingModelOptions(); - sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get())); - sharing_feature_processor_.reset( - new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions())); - sharing_feature_fn_ = CreateFeatureVectorFn( - 
*sharing_network_, sharing_network_->EmbeddingSize(0)); - - InitializeSharingRegexPatterns(std::vector<SharingModelOptions::RegexPattern>( - sharing_options_.regex_pattern().begin(), - sharing_options_.regex_pattern().end())); - - return true; -} - -bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) { - ScopedMmap mmap = ScopedMmap(fd); - if (!mmap.handle().ok()) { - TC_LOG(ERROR) << "Can't mmap."; - return false; - } - - const char *selection_model, *sharing_model; - int selection_model_length, sharing_model_length; - if (!ParseMergedModel(mmap.handle().start(), mmap.handle().num_bytes(), - &selection_model, &selection_model_length, - &sharing_model, &sharing_model_length)) { - TC_LOG(ERROR) << "Couldn't parse merged model."; - return false; - } - - MemoryImageReader<EmbeddingNetworkProto> reader(selection_model, - selection_model_length); - - auto model_options_extension_id = model_options_in_embedding_network_proto; - if (reader.trimmed_proto().HasExtension(model_options_extension_id)) { - *model_options = - reader.trimmed_proto().GetExtension(model_options_extension_id); - return true; - } else { - return false; - } -} - -EmbeddingNetwork::Vector TextClassificationModel::InferInternal( - const std::string& context, CodepointSpan span, - const FeatureProcessor& feature_processor, const EmbeddingNetwork& network, - const FeatureVectorFn& feature_vector_fn, - std::vector<CodepointSpan>* selection_label_spans) const { - std::vector<Token> tokens; - int click_pos; - std::unique_ptr<CachedFeatures> cached_features; - const int embedding_size = network.EmbeddingSize(0); - if (!feature_processor.ExtractFeatures( - context, span, /*relative_click_span=*/{0, 0}, - CreateFeatureVectorFn(network, embedding_size), - embedding_size + feature_processor.DenseFeaturesCount(), &tokens, - &click_pos, &cached_features)) { - TC_VLOG(1) << "Could not extract features."; - return {}; - } - - VectorSpan<float> features; - VectorSpan<Token> output_tokens; - if 
(!cached_features->Get(click_pos, &features, &output_tokens)) { - TC_VLOG(1) << "Could not extract features."; - return {}; - } - - if (selection_label_spans != nullptr) { - if (!feature_processor.SelectionLabelSpans(output_tokens, - selection_label_spans)) { - TC_LOG(ERROR) << "Could not get spans for selection labels."; - return {}; - } - } - - std::vector<float> scores; - network.ComputeLogits(features, &scores); - return scores; -} - -namespace { - -// Returns true if given codepoint is contained in the given span in context. -bool IsCodepointInSpan(const char32 codepoint, const std::string& context, - const CodepointSpan span) { - const UnicodeText context_unicode = - UTF8ToUnicodeText(context, /*do_copy=*/false); - - auto begin_it = context_unicode.begin(); - std::advance(begin_it, span.first); - auto end_it = context_unicode.begin(); - std::advance(end_it, span.second); - - return std::find(begin_it, end_it, codepoint) != end_it; -} - -// Returns the first codepoint of the span. -char32 FirstSpanCodepoint(const std::string& context, - const CodepointSpan span) { - const UnicodeText context_unicode = - UTF8ToUnicodeText(context, /*do_copy=*/false); - - auto it = context_unicode.begin(); - std::advance(it, span.first); - return *it; -} - -// Returns the last codepoint of the span. 
-char32 LastSpanCodepoint(const std::string& context, const CodepointSpan span) { - const UnicodeText context_unicode = - UTF8ToUnicodeText(context, /*do_copy=*/false); - - auto it = context_unicode.begin(); - std::advance(it, span.second - 1); - return *it; -} - -} // namespace - -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - -namespace { - -bool IsOpenBracket(const char32 codepoint) { - return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) == - U_BPT_OPEN; -} - -bool IsClosingBracket(const char32 codepoint) { - return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) == - U_BPT_CLOSE; -} - -} // namespace - -// If the first or the last codepoint of the given span is a bracket, the -// bracket is stripped if the span does not contain its corresponding paired -// version. -CodepointSpan StripUnpairedBrackets(const std::string& context, - CodepointSpan span) { - if (context.empty()) { - return span; - } - - const char32 begin_char = FirstSpanCodepoint(context, span); - - const char32 paired_begin_char = u_getBidiPairedBracket(begin_char); - if (paired_begin_char != begin_char) { - if (!IsOpenBracket(begin_char) || - !IsCodepointInSpan(paired_begin_char, context, span)) { - ++span.first; - } - } - - if (span.first == span.second) { - return span; - } - - const char32 end_char = LastSpanCodepoint(context, span); - const char32 paired_end_char = u_getBidiPairedBracket(end_char); - if (paired_end_char != end_char) { - if (!IsClosingBracket(end_char) || - !IsCodepointInSpan(paired_end_char, context, span)) { - --span.second; - } - } - - // Should not happen, but let's make sure. 
- if (span.first > span.second) { - TC_LOG(WARNING) << "Inverse indices result: " << span.first << ", " - << span.second; - span.second = span.first; - } - - return span; -} -#endif - -CodepointSpan TextClassificationModel::SuggestSelection( - const std::string& context, CodepointSpan click_indices) const { - if (!initialized_) { - TC_LOG(ERROR) << "Not initialized"; - return click_indices; - } - - const int context_codepoint_size = - UTF8ToUnicodeText(context, /*do_copy=*/false).size(); - - if (click_indices.first < 0 || click_indices.second < 0 || - click_indices.first >= context_codepoint_size || - click_indices.second > context_codepoint_size || - click_indices.first >= click_indices.second) { - TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: " - << click_indices.first << " " << click_indices.second; - return click_indices; - } - - CodepointSpan result; - if (selection_options_.enforce_symmetry()) { - result = SuggestSelectionSymmetrical(context, click_indices); - } else { - float score; - std::tie(result, score) = SuggestSelectionInternal(context, click_indices); - } - -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - if (selection_options_.strip_unpaired_brackets()) { - const CodepointSpan stripped_result = - StripUnpairedBrackets(context, result); - if (stripped_result.first != stripped_result.second) { - result = stripped_result; - } - } -#endif - - return result; -} - -namespace { - -int BestPrediction(const std::vector<float>& scores) { - if (!scores.empty()) { - const int prediction = - std::max_element(scores.begin(), scores.end()) - scores.begin(); - return prediction; - } else { - return kInvalidLabel; - } -} - -std::pair<CodepointSpan, float> BestSelectionSpan( - CodepointSpan original_click_indices, const std::vector<float>& scores, - const std::vector<CodepointSpan>& selection_label_spans) { - const int prediction = BestPrediction(scores); - if (prediction != kInvalidLabel) { - std::pair<CodepointIndex, CodepointIndex> selection 
= - selection_label_spans[prediction]; - - if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) { - TC_VLOG(1) << "Invalid indices predicted, returning input: " << prediction - << " " << selection.first << " " << selection.second; - return {original_click_indices, -1.0}; - } - - return {{selection.first, selection.second}, scores[prediction]}; - } else { - TC_LOG(ERROR) << "Returning default selection: scores.size() = " - << scores.size(); - return {original_click_indices, -1.0}; - } -} - -} // namespace - -std::pair<CodepointSpan, float> -TextClassificationModel::SuggestSelectionInternal( - const std::string& context, CodepointSpan click_indices) const { - if (!initialized_) { - TC_LOG(ERROR) << "Not initialized"; - return {click_indices, -1.0}; - } - - std::vector<CodepointSpan> selection_label_spans; - EmbeddingNetwork::Vector scores = InferInternal( - context, click_indices, *selection_feature_processor_, - *selection_network_, selection_feature_fn_, &selection_label_spans); - scores = nlp_core::ComputeSoftmax(scores); - - return BestSelectionSpan(click_indices, scores, selection_label_spans); -} - -// Implements a greedy-search-like algorithm for making selections symmetric. -// -// Steps: -// 1. Get a set of selection proposals from places around the clicked word. -// 2. For each proposal (going from highest-scoring), check if the tokens that -// the proposal selects are still free, in which case it claims them, if a -// proposal that contains the clicked token is found, it is returned as the -// suggestion. -// -// This algorithm should ensure that if a selection is proposed, it does not -// matter which word of it was tapped - all of them will lead to the same -// selection. 
-CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical( - const std::string& context, CodepointSpan click_indices) const { - const int symmetry_context_size = selection_options_.symmetry_context_size(); - std::vector<CodepointSpan> chunks = Chunk( - context, click_indices, {symmetry_context_size, symmetry_context_size}); - for (const CodepointSpan& chunk : chunks) { - // If chunk and click indices have an overlap, return the chunk. - if (!(click_indices.first >= chunk.second || - click_indices.second <= chunk.first)) { - return chunk; - } - } - - return click_indices; -} - -std::vector<std::pair<std::string, float>> -TextClassificationModel::ClassifyText(const std::string& context, - CodepointSpan selection_indices, - int hint_flags) const { - if (!initialized_) { - TC_LOG(ERROR) << "Not initialized"; - return {}; - } - - if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) { - TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: " - << std::get<0>(selection_indices) << " " - << std::get<1>(selection_indices); - return {}; - } - - if (hint_flags & SELECTION_IS_URL && - sharing_options_.always_accept_url_hint()) { - return {{kUrlHintCollection, 1.0}}; - } - - if (hint_flags & SELECTION_IS_EMAIL && - sharing_options_.always_accept_email_hint()) { - return {{kEmailHintCollection, 1.0}}; - } - - // Check whether any of the regular expressions match. 
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - const std::string selection_text = - ExtractSelection(context, selection_indices); - for (const CompiledRegexPattern& regex_pattern : regex_patterns_) { - if (MatchesRegex(regex_pattern.pattern.get(), selection_text)) { - return {{regex_pattern.collection_name, 1.0}}; - } - } -#endif - - EmbeddingNetwork::Vector scores = - InferInternal(context, selection_indices, *sharing_feature_processor_, - *sharing_network_, sharing_feature_fn_, nullptr); - if (scores.empty() || - scores.size() != sharing_feature_processor_->NumCollections()) { - TC_VLOG(1) << "Using default class: scores.size() = " << scores.size(); - return {}; - } - - scores = nlp_core::ComputeSoftmax(scores); - - std::vector<std::pair<std::string, float>> result(scores.size()); - for (int i = 0; i < scores.size(); i++) { - result[i] = {sharing_feature_processor_->LabelToCollection(i), scores[i]}; - } - std::sort(result.begin(), result.end(), - [](const std::pair<std::string, float>& a, - const std::pair<std::string, float>& b) { - return a.second > b.second; - }); - - // Phone class sanity check. 
- if (result.begin()->first == kPhoneCollection) { - const int digit_count = CountDigits(context, selection_indices); - if (digit_count < sharing_options_.phone_min_num_digits() || - digit_count > sharing_options_.phone_max_num_digits()) { - return {{kOtherCollection, 1.0}}; - } - } - - return result; -} - -std::vector<CodepointSpan> TextClassificationModel::Chunk( - const std::string& context, CodepointSpan click_span, - TokenSpan relative_click_span) const { - std::unique_ptr<CachedFeatures> cached_features; - std::vector<Token> tokens; - int click_index; - int embedding_size = selection_network_->EmbeddingSize(0); - if (!selection_feature_processor_->ExtractFeatures( - context, click_span, relative_click_span, selection_feature_fn_, - embedding_size + selection_feature_processor_->DenseFeaturesCount(), - &tokens, &click_index, &cached_features)) { - TC_VLOG(1) << "Couldn't ExtractFeatures."; - return {}; - } - - int first_token; - int last_token; - if (relative_click_span.first == kInvalidIndex || - relative_click_span.second == kInvalidIndex) { - first_token = 0; - last_token = tokens.size(); - } else { - first_token = click_index - relative_click_span.first; - last_token = click_index + relative_click_span.second + 1; - } - - struct SelectionProposal { - int label; - int token_index; - CodepointSpan span; - float score; - }; - - // Scan in the symmetry context for selection span proposals. 
- std::vector<SelectionProposal> proposals; - for (int token_index = first_token; token_index < last_token; ++token_index) { - if (token_index < 0 || token_index >= tokens.size() || - tokens[token_index].is_padding) { - continue; - } - - float score; - VectorSpan<float> features; - VectorSpan<Token> output_tokens; - std::vector<CodepointSpan> selection_label_spans; - CodepointSpan span; - if (cached_features->Get(token_index, &features, &output_tokens) && - selection_feature_processor_->SelectionLabelSpans( - output_tokens, &selection_label_spans)) { - // Add an implicit proposal for each token to be by itself. Every - // token should be now represented in the results. - proposals.push_back( - SelectionProposal{0, token_index, selection_label_spans[0], 0.0}); - - std::vector<float> scores; - selection_network_->ComputeLogits(features, &scores); - - scores = nlp_core::ComputeSoftmax(scores); - std::tie(span, score) = BestSelectionSpan({kInvalidIndex, kInvalidIndex}, - scores, selection_label_spans); - if (span.first != kInvalidIndex && span.second != kInvalidIndex && - score >= 0) { - const int prediction = BestPrediction(scores); - proposals.push_back( - SelectionProposal{prediction, token_index, span, score}); - } - } else { - // Add an implicit proposal for each token to be by itself. Every token - // should be now represented in the results. - proposals.push_back(SelectionProposal{ - 0, - token_index, - {tokens[token_index].start, tokens[token_index].end}, - 0.0}); - } - } - - // Sort selection span proposals by their respective probabilities. - std::sort(proposals.begin(), proposals.end(), - [](const SelectionProposal& a, const SelectionProposal& b) { - return a.score > b.score; - }); - - // Go from the highest-scoring proposal and claim tokens. Tokens are marked as - // claimed by the higher-scoring selection proposals, so that the - // lower-scoring ones cannot use them. Returns the selection proposal if it - // contains the clicked token. 
- std::vector<CodepointSpan> result; - std::vector<bool> token_used(tokens.size(), false); - for (const SelectionProposal& proposal : proposals) { - const int predicted_label = proposal.label; - TokenSpan relative_span; - if (!selection_feature_processor_->LabelToTokenSpan(predicted_label, - &relative_span)) { - continue; - } - TokenSpan span; - span.first = proposal.token_index - relative_span.first; - span.second = proposal.token_index + relative_span.second + 1; - - if (span.first != kInvalidIndex && span.second != kInvalidIndex) { - bool feasible = true; - for (int i = span.first; i < span.second; i++) { - if (token_used[i]) { - feasible = false; - break; - } - } - - if (feasible) { - result.push_back(proposal.span); - for (int i = span.first; i < span.second; i++) { - token_used[i] = true; - } - } - } - } - - std::sort(result.begin(), result.end(), - [](const CodepointSpan& a, const CodepointSpan& b) { - return a.first < b.first; - }); - - return result; -} - -std::vector<TextClassificationModel::AnnotatedSpan> -TextClassificationModel::Annotate(const std::string& context) const { - std::vector<CodepointSpan> chunks; - const UnicodeText context_unicode = UTF8ToUnicodeText(context, - /*do_copy=*/false); - for (const UnicodeTextRange& line : - selection_feature_processor_->SplitContext(context_unicode)) { - const std::vector<CodepointSpan> local_chunks = - Chunk(UnicodeText::UTF8Substring(line.first, line.second), - /*click_span=*/{kInvalidIndex, kInvalidIndex}, - /*relative_click_span=*/{kInvalidIndex, kInvalidIndex}); - const int offset = std::distance(context_unicode.begin(), line.first); - for (CodepointSpan chunk : local_chunks) { - chunks.push_back({chunk.first + offset, chunk.second + offset}); - } - } - - std::vector<TextClassificationModel::AnnotatedSpan> result; - for (const CodepointSpan& chunk : chunks) { - result.emplace_back(); - result.back().span = chunk; - result.back().classification = ClassifyText(context, chunk); - } - return result; -} - -} 
// namespace libtextclassifier diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h deleted file mode 100644 index d0df193..0000000 --- a/smartselect/text-classification-model.h +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Inference code for the feed-forward text classification models. - -#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ - -#include <memory> -#include <set> -#include <string> - -#include "common/embedding-network.h" -#include "common/feature-extractor.h" -#include "common/memory_image/embedding-network-params-from-image.h" -#include "common/mmap.h" -#include "smartselect/feature-processor.h" -#include "smartselect/model-params.h" -#include "smartselect/text-classification-model.pb.h" -#include "smartselect/types.h" - -namespace libtextclassifier { - -// SmartSelection/Sharing feed-forward model. -class TextClassificationModel { - public: - // Represents a result of Annotate call. - struct AnnotatedSpan { - // Unicode codepoint indices in the input string. - CodepointSpan span = {kInvalidIndex, kInvalidIndex}; - - // Classification result for the span. - std::vector<std::pair<std::string, float>> classification; - }; - - // Loads TextClassificationModel from given file given by an int - // file descriptor. 
- // Offset is byte a position in the file to the beginning of the model data. - TextClassificationModel(int fd, int offset, int size); - - // Same as above but the whole file is mapped and it is assumed the model - // starts at offset 0. - explicit TextClassificationModel(int fd); - - // Loads TextClassificationModel from given file. - explicit TextClassificationModel(const std::string& path); - - // Loads TextClassificationModel from given location in memory. - TextClassificationModel(const void* addr, int size); - - // Returns true if the model is ready for use. - bool IsInitialized() { return initialized_; } - - // Bit flags for the input selection. - enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 }; - - // Runs inference for given a context and current selection (i.e. index - // of the first and one past last selected characters (utf8 codepoint - // offsets)). Returns the indices (utf8 codepoint offsets) of the selection - // beginning character and one past selection end character. - // Returns the original click_indices if an error occurs. - // NOTE: The selection indices are passed in and returned in terms of - // UTF8 codepoints (not bytes). - // Requires that the model is a smart selection model. - CodepointSpan SuggestSelection(const std::string& context, - CodepointSpan click_indices) const; - - // Classifies the selected text given the context string. - // Requires that the model is a smart sharing model. - // Returns an empty result if an error occurs. - std::vector<std::pair<std::string, float>> ClassifyText( - const std::string& context, CodepointSpan click_indices, - int input_flags = 0) const; - - // Annotates given input text. The annotations should cover the whole input - // context except for whitespaces, and are sorted by their position in the - // context string. - std::vector<AnnotatedSpan> Annotate(const std::string& context) const; - - protected: - // Initializes the model from mmap_ file. 
- void InitFromMmap(); - - // Extracts chunks from the context. The extraction proceeds from the center - // token determined by click_span and looks at relative_click_span tokens - // left and right around the click position. - // If relative_click_span == {kInvalidIndex, kInvalidIndex} then the whole - // context is considered, regardless of the click_span. - // Returns the chunks sorted by their position in the context string. - std::vector<CodepointSpan> Chunk(const std::string& context, - CodepointSpan click_span, - TokenSpan relative_click_span) const; - - // During evaluation we need access to the feature processor. - FeatureProcessor* SelectionFeatureProcessor() const { - return selection_feature_processor_.get(); - } - - void InitializeSharingRegexPatterns( - const std::vector<SharingModelOptions::RegexPattern>& patterns); - - // Collection name when url hint is accepted. - const std::string kUrlHintCollection = "url"; - - // Collection name when email hint is accepted. - const std::string kEmailHintCollection = "email"; - - // Collection name for other. - const std::string kOtherCollection = "other"; - - // Collection name for phone. - const std::string kPhoneCollection = "phone"; - - SelectionModelOptions selection_options_; - SharingModelOptions sharing_options_; - - private: -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - struct CompiledRegexPattern { - std::string collection_name; - std::unique_ptr<icu::RegexPattern> pattern; - }; -#endif - - bool LoadModels(const void* addr, int size); - - nlp_core::EmbeddingNetwork::Vector InferInternal( - const std::string& context, CodepointSpan span, - const FeatureProcessor& feature_processor, - const nlp_core::EmbeddingNetwork& network, - const FeatureVectorFn& feature_vector_fn, - std::vector<CodepointSpan>* selection_label_spans) const; - - // Returns a selection suggestion with a score. 
- std::pair<CodepointSpan, float> SuggestSelectionInternal( - const std::string& context, CodepointSpan click_indices) const; - - // Returns a selection suggestion and makes sure it's symmetric. Internally - // runs several times SuggestSelectionInternal. - CodepointSpan SuggestSelectionSymmetrical(const std::string& context, - CodepointSpan click_indices) const; - - bool initialized_ = false; - std::unique_ptr<nlp_core::ScopedMmap> mmap_; - std::unique_ptr<ModelParams> selection_params_; - std::unique_ptr<FeatureProcessor> selection_feature_processor_; - std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_; - FeatureVectorFn selection_feature_fn_; - std::unique_ptr<FeatureProcessor> sharing_feature_processor_; - std::unique_ptr<ModelParams> sharing_params_; - std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_; - FeatureVectorFn sharing_feature_fn_; -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - std::vector<CompiledRegexPattern> regex_patterns_; -#endif -}; - -// If the first or the last codepoint of the given span is a bracket, the -// bracket is stripped if the span does not contain its corresponding paired -// version. -CodepointSpan StripUnpairedBrackets(const std::string& context, - CodepointSpan span); - -// Parses the merged image given as a file descriptor, and reads -// the ModelOptions proto from the selection model. -bool ReadSelectionModelOptions(int fd, ModelOptions* model_options); - -// Pretty-printing function for TextClassificationModel::AnnotatedSpan. 
-inline std::ostream& operator<<( - std::ostream& os, const TextClassificationModel::AnnotatedSpan& span) { - std::string best_class; - float best_score = -1; - if (!span.classification.empty()) { - best_class = span.classification[0].first; - best_score = span.classification[0].second; - } - return os << "Span(" << span.span.first << ", " << span.span.second << ", " - << best_class << ", " << best_score << ")"; -} - -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto deleted file mode 100644 index 315e849..0000000 --- a/smartselect/text-classification-model.proto +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Text classification model configuration. - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -import "external/libtextclassifier/common/embedding-network.proto"; -import "external/libtextclassifier/smartselect/tokenizer.proto"; - -package libtextclassifier; - -// Generic options of a model, non-specific to selection or sharing. -message ModelOptions { - // If true, will use embeddings from a different model. This is mainly useful - // for the Sharing model using the embeddings from the Selection model. - optional bool use_shared_embeddings = 1; - - // Language of the model. 
- optional string language = 2; - - // Version of the model. - optional int32 version = 3; -} - -message SelectionModelOptions { - // A list of Unicode codepoints to strip from predicted selections. - repeated int32 deprecated_punctuation_to_strip = 1; - - // Enforce symmetrical selections. - optional bool enforce_symmetry = 3; - - // Number of inferences made around the click position (to one side), for - // enforcing symmetry. - optional int32 symmetry_context_size = 4; - - // If true, before the selection is returned, the unpaired brackets contained - // in the predicted selection are stripped from the both selection ends. - // The bracket codepoints are defined in the Unicode standard: - // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt - optional bool strip_unpaired_brackets = 5 [default = true]; - - reserved 2; -} - -message SharingModelOptions { - // If true, will always return "url" when the url hint is passed in. - optional bool always_accept_url_hint = 1; - - // If true, will always return "email" when the e-mail hint is passed in. - optional bool always_accept_email_hint = 2; - - // Limits for phone numbers. - optional int32 phone_min_num_digits = 3 [default = 7]; - optional int32 phone_max_num_digits = 4 [default = 15]; - - // List of regular expression matchers to check. - message RegexPattern { - // The name of the collection of a match. - optional string collection_name = 1; - - // The pattern to check. - optional string pattern = 2; - } - repeated RegexPattern regex_pattern = 5; -} - -// Next ID: 41 -message FeatureProcessorOptions { - // Number of buckets used for hashing charactergrams. - optional int32 num_buckets = 1 [default = -1]; - - // Context size defines the number of words to the left and to the right of - // the selected word to be used as context. For example, if context size is - // N, then we take N words to the left and N words to the right of the - // selected word as its context. 
- optional int32 context_size = 2 [default = -1]; - - // Maximum number of words of the context to select in total. - optional int32 max_selection_span = 3 [default = -1]; - - // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3 - // character trigrams etc. - repeated int32 chargram_orders = 4; - - // Maximum length of a word, in codepoints. - optional int32 max_word_length = 21 [default = 20]; - - // If true, will use the unicode-aware functionality for extracting features. - optional bool unicode_aware_features = 19 [default = false]; - - // Whether to extract the token case feature. - optional bool extract_case_feature = 5 [default = false]; - - // Whether to extract the selection mask feature. - optional bool extract_selection_mask_feature = 6 [default = false]; - - // List of regexps to run over each token. For each regexp, if there is a - // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used. - repeated string regexp_feature = 22; - - // Whether to remap all digits to a single number. - optional bool remap_digits = 20 [default = false]; - - // Whether to lower-case each token before generating hashgrams. - optional bool lowercase_tokens = 33; - - // If true, the selection classifier output will contain only the selections - // that are feasible (e.g., those that are shorter than max_selection_span), - // if false, the output will be a complete cross-product of possible - // selections to the left and posible selections to the right, including the - // infeasible ones. - // NOTE: Exists mainly for compatibility with older models that were trained - // with the non-reduced output space. - optional bool selection_reduced_output_space = 8 [default = true]; - - // Collection names. - repeated string collections = 9; - - // An index of collection in collections to be used if a collection name can't - // be mapped to an id. 
- optional int32 default_collection = 10 [default = -1]; - - // If true, will split the input by lines, and only use the line that contains - // the clicked token. - optional bool only_use_line_with_click = 13 [default = false]; - - // If true, will split tokens that contain the selection boundary, at the - // position of the boundary. - // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" - optional bool split_tokens_on_selection_boundaries = 14 [default = false]; - - // Codepoint ranges that determine how different codepoints are tokenized. - // The ranges must not overlap. - repeated TokenizationCodepointRange tokenization_codepoint_config = 15; - - // Method for selecting the center token. - enum CenterTokenSelectionMethod { - DEFAULT_CENTER_TOKEN_METHOD = 0; // Invalid option. - - // Use click indices to determine the center token. - CENTER_TOKEN_FROM_CLICK = 1; - - // Use selection indices to get a token range, and select the middle of it - // as the center token. - CENTER_TOKEN_MIDDLE_OF_SELECTION = 2; - } - optional CenterTokenSelectionMethod center_token_selection_method = 16; - - // If true, span boundaries will be snapped to containing tokens and not - // required to exactly match token boundaries. - optional bool snap_label_span_boundaries_to_containing_tokens = 18; - - // Range of codepoints start - end, where end is exclusive. - message CodepointRange { - optional int32 start = 1; - optional int32 end = 2; - } - - // A set of codepoint ranges supported by the model. - repeated CodepointRange supported_codepoint_ranges = 23; - - // A set of codepoint ranges to use in the mixed tokenization mode to identify - // stretches of tokens to re-tokenize using the internal tokenizer. - repeated CodepointRange internal_tokenizer_codepoint_ranges = 34; - - // Minimum ratio of supported codepoints in the input context. If the ratio - // is lower than this, the feature computation will fail. 
- optional float min_supported_codepoint_ratio = 24 [default = 0.0]; - - // Used for versioning the format of features the model expects. - // - feature_version == 0: - // For each token the features consist of: - // - chargram embeddings - // - dense features - // Chargram embeddings for tokens are concatenated first together, - // and at the end, the dense features for the tokens are concatenated - // to it. So the resulting feature vector has two regions. - optional int32 feature_version = 25 [default = 0]; - - // Controls the type of tokenization the model will use for the input text. - enum TokenizationType { - INVALID_TOKENIZATION_TYPE = 0; - - // Use the internal tokenizer for tokenization. - INTERNAL_TOKENIZER = 1; - - // Use ICU for tokenization. - ICU = 2; - - // First apply ICU tokenization. Then identify stretches of tokens - // consisting only of codepoints in internal_tokenizer_codepoint_ranges - // and re-tokenize them using the internal tokenizer. - MIXED = 3; - } - optional TokenizationType tokenization_type = 30 - [default = INTERNAL_TOKENIZER]; - optional bool icu_preserve_whitespace_tokens = 31 [default = false]; - - // List of codepoints that will be stripped from beginning and end of - // predicted spans. - repeated int32 ignored_span_boundary_codepoints = 36; - - reserved 7, 11, 12, 26, 27, 28, 29, 32, 35, 39, 40; - - // List of allowed charactergrams. The extracted charactergrams are filtered - // using this list, and charactergrams that are not present are interpreted as - // out-of-vocabulary. - // If no allowed_chargrams are specified, all charactergrams are allowed. - // The field is typed as bytes type to allow non-UTF8 chargrams. 
- repeated bytes allowed_chargrams = 38; -}; - -extend nlp_core.EmbeddingNetworkProto { - optional ModelOptions model_options_in_embedding_network_proto = 150063045; - optional FeatureProcessorOptions - feature_processor_options_in_embedding_network_proto = 146230910; - optional SelectionModelOptions - selection_model_options_in_embedding_network_proto = 148190899; - optional SharingModelOptions - sharing_model_options_in_embedding_network_proto = 151445439; -} diff --git a/smartselect/text-classification-model_test.cc b/smartselect/text-classification-model_test.cc deleted file mode 100644 index 5550e53..0000000 --- a/smartselect/text-classification-model_test.cc +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "smartselect/text-classification-model.h" - -#include <fcntl.h> -#include <stdio.h> -#include <fstream> -#include <iostream> -#include <memory> -#include <string> - -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace { - -std::string ReadFile(const std::string& file_name) { - std::ifstream file_stream(file_name); - return std::string(std::istreambuf_iterator<char>(file_stream), {}); -} - -std::string GetModelPath() { - return TEST_DATA_DIR "smartselection.model"; -} - -std::string GetURLRegexPath() { - return TEST_DATA_DIR "regex_url.txt"; -} - -std::string GetEmailRegexPath() { - return TEST_DATA_DIR "regex_email.txt"; -} - -TEST(TextClassificationModelTest, StripUnpairedBrackets) { - // Stripping brackets strip brackets from length 1 bracket only selections. - EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}), - std::make_pair(12, 12)); - EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}), - std::make_pair(12, 12)); -} - -TEST(TextClassificationModelTest, ReadModelOptions) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - ModelOptions model_options; - ASSERT_TRUE(ReadSelectionModelOptions(fd, &model_options)); - close(fd); - - EXPECT_EQ("en", model_options.language()); - EXPECT_GT(model_options.version(), 0); -} - -TEST(TextClassificationModelTest, SuggestSelection) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TextClassificationModel> model( - new TextClassificationModel(fd)); - close(fd); - - EXPECT_EQ(model->SuggestSelection( - "this afternoon Barack Obama gave a speech at", {15, 21}), - std::make_pair(15, 27)); - - // Try passing whole string. - // If more than 1 token is specified, we should return back what entered. - EXPECT_EQ(model->SuggestSelection("350 Third Street, Cambridge", {0, 27}), - std::make_pair(0, 27)); - - // Single letter. 
- EXPECT_EQ(std::make_pair(0, 1), model->SuggestSelection("a", {0, 1})); - - // Single word. - EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4})); - - EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556 today", {11, 14}), - std::make_pair(11, 23)); - - // Unpaired bracket stripping. - EXPECT_EQ( - model->SuggestSelection("call me at (857) 225 3556 today", {11, 16}), - std::make_pair(11, 25)); - EXPECT_EQ(model->SuggestSelection("call me at (857 225 3556 today", {11, 15}), - std::make_pair(12, 24)); - EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556) today", {11, 14}), - std::make_pair(11, 23)); - EXPECT_EQ( - model->SuggestSelection("call me at )857 225 3556( today", {11, 15}), - std::make_pair(12, 24)); - - // If the resulting selection would be empty, the original span is returned. - EXPECT_EQ(model->SuggestSelection("call me at )( today", {11, 13}), - std::make_pair(11, 13)); - EXPECT_EQ(model->SuggestSelection("call me at ( today", {11, 12}), - std::make_pair(11, 12)); - EXPECT_EQ(model->SuggestSelection("call me at ) today", {11, 12}), - std::make_pair(11, 12)); -} - -TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TextClassificationModel> model( - new TextClassificationModel(fd)); - close(fd); - - EXPECT_EQ(std::make_pair(0, 27), - model->SuggestSelection("350 Third Street, Cambridge", {0, 3})); - EXPECT_EQ(std::make_pair(0, 27), - model->SuggestSelection("350 Third Street, Cambridge", {4, 9})); - EXPECT_EQ(std::make_pair(0, 27), - model->SuggestSelection("350 Third Street, Cambridge", {10, 16})); - EXPECT_EQ(std::make_pair(6, 33), - model->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge", - {16, 22})); -} - -TEST(TextClassificationModelTest, SuggestSelectionWithNewLine) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - 
std::unique_ptr<TextClassificationModel> model( - new TextClassificationModel(fd)); - close(fd); - - std::tuple<int, int> selection; - selection = model->SuggestSelection("abc\nBarack Obama", {4, 10}); - EXPECT_EQ(4, std::get<0>(selection)); - EXPECT_EQ(16, std::get<1>(selection)); - - selection = model->SuggestSelection("Barack Obama\nabc", {0, 6}); - EXPECT_EQ(0, std::get<0>(selection)); - EXPECT_EQ(12, std::get<1>(selection)); -} - -TEST(TextClassificationModelTest, SuggestSelectionWithPunctuation) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TextClassificationModel> model( - new TextClassificationModel(fd)); - close(fd); - - std::tuple<int, int> selection; - - // From the right. - selection = model->SuggestSelection( - "this afternoon Barack Obama, gave a speech at", {15, 21}); - EXPECT_EQ(15, std::get<0>(selection)); - EXPECT_EQ(27, std::get<1>(selection)); - - // From the right multiple. - selection = model->SuggestSelection( - "this afternoon Barack Obama,.,.,, gave a speech at", {15, 21}); - EXPECT_EQ(15, std::get<0>(selection)); - EXPECT_EQ(27, std::get<1>(selection)); - - // From the left multiple. - selection = model->SuggestSelection( - "this afternoon ,.,.,,Barack Obama gave a speech at", {21, 27}); - EXPECT_EQ(21, std::get<0>(selection)); - EXPECT_EQ(27, std::get<1>(selection)); - - // From both sides. 
- selection = model->SuggestSelection( - "this afternoon !Barack Obama,- gave a speech at", {16, 22}); - EXPECT_EQ(16, std::get<0>(selection)); - EXPECT_EQ(28, std::get<1>(selection)); -} - -class TestingTextClassificationModel - : public libtextclassifier::TextClassificationModel { - public: - explicit TestingTextClassificationModel(int fd) - : libtextclassifier::TextClassificationModel(fd) {} - - using TextClassificationModel::InitializeSharingRegexPatterns; - - void DisableClassificationHints() { - sharing_options_.set_always_accept_url_hint(false); - sharing_options_.set_always_accept_email_hint(false); - } -}; - -TEST(TextClassificationModelTest, SuggestSelectionNoCrashWithJunk) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TextClassificationModel> ff_model( - new TextClassificationModel(fd)); - close(fd); - - std::tuple<int, int> selection; - - // Try passing in bunch of invalid selections. - selection = ff_model->SuggestSelection("", {0, 27}); - // If more than 1 token is specified, we should return back what entered. - EXPECT_EQ(0, std::get<0>(selection)); - EXPECT_EQ(27, std::get<1>(selection)); - - selection = ff_model->SuggestSelection("", {-10, 27}); - // If more than 1 token is specified, we should return back what entered. - EXPECT_EQ(-10, std::get<0>(selection)); - EXPECT_EQ(27, std::get<1>(selection)); - - selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {0, 27}); - // If more than 1 token is specified, we should return back what entered. - EXPECT_EQ(0, std::get<0>(selection)); - EXPECT_EQ(27, std::get<1>(selection)); - - selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-30, 300}); - // If more than 1 token is specified, we should return back what entered. 
- EXPECT_EQ(-30, std::get<0>(selection)); - EXPECT_EQ(300, std::get<1>(selection)); - - selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-10, -1}); - // If more than 1 token is specified, we should return back what entered. - EXPECT_EQ(-10, std::get<0>(selection)); - EXPECT_EQ(-1, std::get<1>(selection)); - - selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {100, 17}); - // If more than 1 token is specified, we should return back what entered. - EXPECT_EQ(100, std::get<0>(selection)); - EXPECT_EQ(17, std::get<1>(selection)); -} - -namespace { - -std::string FindBestResult(std::vector<std::pair<std::string, float>> results) { - if (results.empty()) { - return "<INVALID RESULTS>"; - } - - std::sort(results.begin(), results.end(), - [](const std::pair<std::string, float> a, - const std::pair<std::string, float> b) { - return a.second > b.second; - }); - return results[0].first; -} - -} // namespace - -TEST(TextClassificationModelTest, ClassifyText) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TestingTextClassificationModel> model( - new TestingTextClassificationModel(fd)); - close(fd); - - model->DisableClassificationHints(); - EXPECT_EQ("other", - FindBestResult(model->ClassifyText( - "this afternoon Barack Obama gave a speech at", {15, 27}))); - EXPECT_EQ("other", - FindBestResult(model->ClassifyText("you@android.com", {0, 15}))); - EXPECT_EQ("other", FindBestResult(model->ClassifyText( - "Contact me at you@android.com", {14, 29}))); - EXPECT_EQ("phone", FindBestResult(model->ClassifyText( - "Call me at (800) 123-456 today", {11, 24}))); - EXPECT_EQ("other", FindBestResult(model->ClassifyText( - "Visit www.google.com every today!", {6, 20}))); - - // More lines. 
- EXPECT_EQ("other", - FindBestResult(model->ClassifyText( - "this afternoon Barack Obama gave a speech at|Visit " - "www.google.com every today!|Call me at (800) 123-456 today.", - {15, 27}))); - EXPECT_EQ("other", - FindBestResult(model->ClassifyText( - "this afternoon Barack Obama gave a speech at|Visit " - "www.google.com every today!|Call me at (800) 123-456 today.", - {51, 65}))); - EXPECT_EQ("phone", - FindBestResult(model->ClassifyText( - "this afternoon Barack Obama gave a speech at|Visit " - "www.google.com every today!|Call me at (800) 123-456 today.", - {90, 103}))); - - // Single word. - EXPECT_EQ("other", FindBestResult(model->ClassifyText("obama", {0, 5}))); - EXPECT_EQ("other", FindBestResult(model->ClassifyText("asdf", {0, 4}))); - EXPECT_EQ("<INVALID RESULTS>", - FindBestResult(model->ClassifyText("asdf", {0, 0}))); - - // Junk. - EXPECT_EQ("<INVALID RESULTS>", - FindBestResult(model->ClassifyText("", {0, 0}))); - EXPECT_EQ("<INVALID RESULTS>", FindBestResult(model->ClassifyText( - "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5}))); -} - -TEST(TextClassificationModelTest, ClassifyTextWithHints) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TestingTextClassificationModel> model( - new TestingTextClassificationModel(fd)); - close(fd); - - // When EMAIL hint is passed, the result should be email. - EXPECT_EQ("email", - FindBestResult(model->ClassifyText( - "x", {0, 1}, TextClassificationModel::SELECTION_IS_EMAIL))); - // When URL hint is passed, the result should be email. - EXPECT_EQ("url", - FindBestResult(model->ClassifyText( - "x", {0, 1}, TextClassificationModel::SELECTION_IS_URL))); - // When both hints are passed, the result should be url (as it's probably - // better to let Chrome handle this case). 
- EXPECT_EQ("url", FindBestResult(model->ClassifyText( - "x", {0, 1}, - TextClassificationModel::SELECTION_IS_EMAIL | - TextClassificationModel::SELECTION_IS_URL))); - - // With disabled hints, we should get the same prediction regardless of the - // hint. - model->DisableClassificationHints(); - EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0), - model->ClassifyText("x", {0, 1}, - TextClassificationModel::SELECTION_IS_EMAIL)); - - EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0), - model->ClassifyText("x", {0, 1}, - TextClassificationModel::SELECTION_IS_URL)); -} - -TEST(TextClassificationModelTest, PhoneFiltering) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TestingTextClassificationModel> model( - new TestingTextClassificationModel(fd)); - close(fd); - - EXPECT_EQ("phone", FindBestResult(model->ClassifyText("phone: (123) 456 789", - {7, 20}, 0))); - EXPECT_EQ("phone", FindBestResult(model->ClassifyText( - "phone: (123) 456 789,0001112", {7, 25}, 0))); - EXPECT_EQ("other", FindBestResult(model->ClassifyText( - "phone: (123) 456 789,0001112", {7, 28}, 0))); -} - -TEST(TextClassificationModelTest, Annotate) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - std::unique_ptr<TestingTextClassificationModel> model( - new TestingTextClassificationModel(fd)); - close(fd); - - std::string test_string = - "& saw Barak Obama today .. 
350 Third Street, Cambridge\nand my phone " - "number is 853 225-3556."; - std::vector<TextClassificationModel::AnnotatedSpan> result = - model->Annotate(test_string); - - std::vector<TextClassificationModel::AnnotatedSpan> expected; - expected.emplace_back(); - expected.back().span = {0, 0}; - expected.emplace_back(); - expected.back().span = {2, 5}; - expected.back().classification.push_back({"other", 1.0}); - expected.emplace_back(); - expected.back().span = {6, 17}; - expected.back().classification.push_back({"other", 1.0}); - expected.emplace_back(); - expected.back().span = {18, 23}; - expected.back().classification.push_back({"other", 1.0}); - expected.emplace_back(); - expected.back().span = {24, 24}; - expected.emplace_back(); - expected.back().span = {27, 54}; - expected.back().classification.push_back({"address", 1.0}); - expected.emplace_back(); - expected.back().span = {55, 58}; - expected.back().classification.push_back({"other", 1.0}); - expected.emplace_back(); - expected.back().span = {59, 61}; - expected.back().classification.push_back({"other", 1.0}); - expected.emplace_back(); - expected.back().span = {62, 74}; - expected.back().classification.push_back({"other", 1.0}); - expected.emplace_back(); - expected.back().span = {75, 77}; - expected.back().classification.push_back({"other", 1.0}); - expected.emplace_back(); - expected.back().span = {78, 90}; - expected.back().classification.push_back({"phone", 1.0}); - - EXPECT_EQ(result.size(), expected.size()); - for (int i = 0; i < expected.size(); ++i) { - EXPECT_EQ(result[i].span, expected[i].span) << result[i]; - if (!expected[i].classification.empty()) { - EXPECT_GT(result[i].classification.size(), 0); - EXPECT_EQ(result[i].classification[0].first, - expected[i].classification[0].first) - << result[i]; - } - } -} - -TEST(TextClassificationModelTest, URLEmailRegex) { - const std::string model_path = GetModelPath(); - int fd = open(model_path.c_str(), O_RDONLY); - 
std::unique_ptr<TestingTextClassificationModel> model( - new TestingTextClassificationModel(fd)); - close(fd); - - SharingModelOptions options; - SharingModelOptions::RegexPattern* email_pattern = - options.add_regex_pattern(); - email_pattern->set_collection_name("email"); - email_pattern->set_pattern(ReadFile(GetEmailRegexPath())); - SharingModelOptions::RegexPattern* url_pattern = options.add_regex_pattern(); - url_pattern->set_collection_name("url"); - url_pattern->set_pattern(ReadFile(GetURLRegexPath())); - - // TODO(b/69538802): Modify directly the model image instead. - model->InitializeSharingRegexPatterns( - {options.regex_pattern().begin(), options.regex_pattern().end()}); - - EXPECT_EQ("url", FindBestResult(model->ClassifyText( - "Visit www.google.com every today!", {6, 20}))); - EXPECT_EQ("email", FindBestResult(model->ClassifyText( - "My email: asdf@something.cz", {10, 27}))); - EXPECT_EQ("url", FindBestResult(model->ClassifyText( - "Login: http://asdf@something.cz", {7, 31}))); -} - -} // namespace -} // namespace libtextclassifier diff --git a/smartselect/tokenizer.h b/smartselect/tokenizer.h deleted file mode 100644 index 4eb78f9..0000000 --- a/smartselect/tokenizer.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_ - -#include <string> -#include <vector> - -#include "smartselect/tokenizer.pb.h" -#include "smartselect/types.h" - -namespace libtextclassifier { - -// Tokenizer splits the input string into a sequence of tokens, according to the -// configuration. -class Tokenizer { - public: - explicit Tokenizer( - const std::vector<TokenizationCodepointRange>& codepoint_ranges); - - // Tokenizes the input string using the selected tokenization method. - std::vector<Token> Tokenize(const std::string& utf8_text) const; - - protected: - // Finds the tokenization role for given codepoint. - // If the character is not found returns DEFAULT_ROLE. - // Internally uses binary search so should be O(log(# of codepoint_ranges)). - TokenizationCodepointRange::Role FindTokenizationRole(int codepoint) const; - - private: - // Codepoint ranges that determine how different codepoints are tokenized. - // The ranges must not overlap. - std::vector<TokenizationCodepointRange> codepoint_ranges_; -}; - -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_ diff --git a/smartselect/tokenizer.proto b/smartselect/tokenizer.proto deleted file mode 100644 index 8e78970..0000000 --- a/smartselect/tokenizer.proto +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (C) 2017 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; - -package libtextclassifier; - -// Represents a codepoint range [start, end) with its role for tokenization. -message TokenizationCodepointRange { - optional int32 start = 1; - optional int32 end = 2; - - // Role of the codepoints in the range. - enum Role { - // Concatenates the codepoint to the current run of codepoints. - DEFAULT_ROLE = 0; - - // Splits a run of codepoints before the current codepoint. - SPLIT_BEFORE = 0x1; - - // Splits a run of codepoints after the current codepoint. - SPLIT_AFTER = 0x2; - - // Discards the codepoint. - DISCARD_CODEPOINT = 0x4; - - // Common values: - // Splits on the characters and discards them. Good e.g. for the space - // character. - WHITESPACE_SEPARATOR = 0x7; - // Each codepoint will be a separate token. Good e.g. for Chinese - // characters. - TOKEN_SEPARATOR = 0x3; - } - optional Role role = 3; -} diff --git a/smartselect/tokenizer_test.cc b/smartselect/tokenizer_test.cc deleted file mode 100644 index cdb90a9..0000000 --- a/smartselect/tokenizer_test.cc +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "smartselect/tokenizer.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace libtextclassifier { -namespace { - -using testing::ElementsAreArray; - -class TestingTokenizer : public Tokenizer { - public: - explicit TestingTokenizer( - const std::vector<TokenizationCodepointRange>& codepoint_range_configs) - : Tokenizer(codepoint_range_configs) {} - - TokenizationCodepointRange::Role TestFindTokenizationRole(int c) const { - return FindTokenizationRole(c); - } -}; - -TEST(TokenizerTest, FindTokenizationRole) { - std::vector<TokenizationCodepointRange> configs; - TokenizationCodepointRange* config; - - configs.emplace_back(); - config = &configs.back(); - config->set_start(0); - config->set_end(10); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - - configs.emplace_back(); - config = &configs.back(); - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - - configs.emplace_back(); - config = &configs.back(); - config->set_start(1234); - config->set_end(12345); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - - TestingTokenizer tokenizer(configs); - - // Test hits to the first group. - EXPECT_EQ(tokenizer.TestFindTokenizationRole(0), - TokenizationCodepointRange::TOKEN_SEPARATOR); - EXPECT_EQ(tokenizer.TestFindTokenizationRole(5), - TokenizationCodepointRange::TOKEN_SEPARATOR); - EXPECT_EQ(tokenizer.TestFindTokenizationRole(10), - TokenizationCodepointRange::DEFAULT_ROLE); - - // Test a hit to the second group. - EXPECT_EQ(tokenizer.TestFindTokenizationRole(31), - TokenizationCodepointRange::DEFAULT_ROLE); - EXPECT_EQ(tokenizer.TestFindTokenizationRole(32), - TokenizationCodepointRange::WHITESPACE_SEPARATOR); - EXPECT_EQ(tokenizer.TestFindTokenizationRole(33), - TokenizationCodepointRange::DEFAULT_ROLE); - - // Test hits to the third group. 
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233), - TokenizationCodepointRange::DEFAULT_ROLE); - EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234), - TokenizationCodepointRange::TOKEN_SEPARATOR); - EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344), - TokenizationCodepointRange::TOKEN_SEPARATOR); - EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345), - TokenizationCodepointRange::DEFAULT_ROLE); - - // Test a hit outside. - EXPECT_EQ(tokenizer.TestFindTokenizationRole(99), - TokenizationCodepointRange::DEFAULT_ROLE); -} - -TEST(TokenizerTest, TokenizeOnSpace) { - std::vector<TokenizationCodepointRange> configs; - TokenizationCodepointRange* config; - - configs.emplace_back(); - config = &configs.back(); - // Space character. - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - - TestingTokenizer tokenizer(configs); - std::vector<Token> tokens = tokenizer.Tokenize("Hello world!"); - - EXPECT_THAT(tokens, - ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)})); -} - -TEST(TokenizerTest, TokenizeComplex) { - std::vector<TokenizationCodepointRange> configs; - TokenizationCodepointRange* config; - - // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt - // Latin - cyrilic. 
- // 0000..007F; Basic Latin - // 0080..00FF; Latin-1 Supplement - // 0100..017F; Latin Extended-A - // 0180..024F; Latin Extended-B - // 0250..02AF; IPA Extensions - // 02B0..02FF; Spacing Modifier Letters - // 0300..036F; Combining Diacritical Marks - // 0370..03FF; Greek and Coptic - // 0400..04FF; Cyrillic - // 0500..052F; Cyrillic Supplement - // 0530..058F; Armenian - // 0590..05FF; Hebrew - // 0600..06FF; Arabic - // 0700..074F; Syriac - // 0750..077F; Arabic Supplement - configs.emplace_back(); - config = &configs.back(); - config->set_start(0); - config->set_end(32); - config->set_role(TokenizationCodepointRange::DEFAULT_ROLE); - configs.emplace_back(); - config = &configs.back(); - config->set_start(32); - config->set_end(33); - config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(33); - config->set_end(0x77F + 1); - config->set_role(TokenizationCodepointRange::DEFAULT_ROLE); - - // CJK - // 2E80..2EFF; CJK Radicals Supplement - // 3000..303F; CJK Symbols and Punctuation - // 3040..309F; Hiragana - // 30A0..30FF; Katakana - // 3100..312F; Bopomofo - // 3130..318F; Hangul Compatibility Jamo - // 3190..319F; Kanbun - // 31A0..31BF; Bopomofo Extended - // 31C0..31EF; CJK Strokes - // 31F0..31FF; Katakana Phonetic Extensions - // 3200..32FF; Enclosed CJK Letters and Months - // 3300..33FF; CJK Compatibility - // 3400..4DBF; CJK Unified Ideographs Extension A - // 4DC0..4DFF; Yijing Hexagram Symbols - // 4E00..9FFF; CJK Unified Ideographs - // A000..A48F; Yi Syllables - // A490..A4CF; Yi Radicals - // A4D0..A4FF; Lisu - // A500..A63F; Vai - // F900..FAFF; CJK Compatibility Ideographs - // FE30..FE4F; CJK Compatibility Forms - // 20000..2A6DF; CJK Unified Ideographs Extension B - // 2A700..2B73F; CJK Unified Ideographs Extension C - // 2B740..2B81F; CJK Unified Ideographs Extension D - // 2B820..2CEAF; CJK Unified Ideographs Extension E - // 2CEB0..2EBEF; CJK Unified 
Ideographs Extension F - // 2F800..2FA1F; CJK Compatibility Ideographs Supplement - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x2E80); - config->set_end(0x2EFF + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x3000); - config->set_end(0xA63F + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0xF900); - config->set_end(0xFAFF + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0xFE30); - config->set_end(0xFE4F + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x20000); - config->set_end(0x2A6DF + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x2A700); - config->set_end(0x2B73F + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x2B740); - config->set_end(0x2B81F + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x2B820); - config->set_end(0x2CEAF + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x2CEB0); - config->set_end(0x2EBEF + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x2F800); - config->set_end(0x2FA1F + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - - // Thai. 
- // 0E00..0E7F; Thai - configs.emplace_back(); - config = &configs.back(); - config->set_start(0x0E00); - config->set_end(0x0E7F + 1); - config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR); - - Tokenizer tokenizer(configs); - std::vector<Token> tokens; - - tokens = tokenizer.Tokenize( - "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。"); - EXPECT_EQ(tokens.size(), 30); - - tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ"); - // clang-format off - EXPECT_THAT( - tokens, - ElementsAreArray({Token("問", 0, 1), - Token("少", 1, 2), - Token("目", 2, 3), - Token("hello", 4, 9), - Token("木", 10, 11), - Token("輸", 11, 12), - Token("ย", 12, 13), - Token("า", 13, 14), - Token("ม", 14, 15), - Token("き", 15, 16), - Token("ゃ", 16, 17)})); - // clang-format on -} - -} // namespace -} // namespace libtextclassifier diff --git a/smartselect/types.h b/smartselect/types.h deleted file mode 100644 index 443e3ac..0000000 --- a/smartselect/types.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_ - -#include <ostream> -#include <string> -#include <utility> - -namespace libtextclassifier { - -constexpr int kInvalidIndex = -1; - -// Index for a 0-based array of tokens. -using TokenIndex = int; - -// Index for a 0-based array of codepoints. 
-using CodepointIndex = int; - -// Marks a span in a sequence of codepoints. The first element is the index of -// the first codepoint of the span, and the second element is the index of the -// codepoint one past the end of the span. -using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>; - -// Marks a span in a sequence of tokens. The first element is the index of the -// first token in the span, and the second element is the index of the token one -// past the end of the span. -using TokenSpan = std::pair<TokenIndex, TokenIndex>; - -// Token holds a token, its position in the original string and whether it was -// part of the input span. -struct Token { - std::string value; - CodepointIndex start; - CodepointIndex end; - - // Whether the token is a padding token. - bool is_padding; - - // Default constructor constructs the padding-token. - Token() - : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {} - - Token(const std::string& arg_value, CodepointIndex arg_start, - CodepointIndex arg_end) - : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {} - - bool operator==(const Token& other) const { - return value == other.value && start == other.start && end == other.end && - is_padding == other.is_padding; - } - - bool IsContainedInSpan(CodepointSpan span) const { - return start >= span.first && end <= span.second; - } -}; - -// Pretty-printing function for Token. 
-inline std::ostream& operator<<(std::ostream& os, const Token& token) { - return os << "Token(\"" << token.value << "\", " << token.start << ", " - << token.end << ", is_padding=" << token.is_padding << ")"; -} - -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_ diff --git a/strip-unpaired-brackets.cc b/strip-unpaired-brackets.cc new file mode 100644 index 0000000..ddf3322 --- /dev/null +++ b/strip-unpaired-brackets.cc @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "strip-unpaired-brackets.h" + +#include <iterator> + +#include "util/base/logging.h" +#include "util/utf8/unicodetext.h" + +namespace libtextclassifier2 { +namespace { + +// Returns true if given codepoint is contained in the given span in context. +bool IsCodepointInSpan(const char32 codepoint, + const UnicodeText& context_unicode, + const CodepointSpan span) { + auto begin_it = context_unicode.begin(); + std::advance(begin_it, span.first); + auto end_it = context_unicode.begin(); + std::advance(end_it, span.second); + + return std::find(begin_it, end_it, codepoint) != end_it; +} + +// Returns the first codepoint of the span. +char32 FirstSpanCodepoint(const UnicodeText& context_unicode, + const CodepointSpan span) { + auto it = context_unicode.begin(); + std::advance(it, span.first); + return *it; +} + +// Returns the last codepoint of the span. 
+char32 LastSpanCodepoint(const UnicodeText& context_unicode, + const CodepointSpan span) { + auto it = context_unicode.begin(); + std::advance(it, span.second - 1); + return *it; +} + +} // namespace + +CodepointSpan StripUnpairedBrackets(const std::string& context, + CodepointSpan span, const UniLib& unilib) { + const UnicodeText context_unicode = + UTF8ToUnicodeText(context, /*do_copy=*/false); + return StripUnpairedBrackets(context_unicode, span, unilib); +} + +// If the first or the last codepoint of the given span is a bracket, the +// bracket is stripped if the span does not contain its corresponding paired +// version. +CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, + CodepointSpan span, const UniLib& unilib) { + if (context_unicode.empty() || !ValidNonEmptySpan(span)) { + return span; + } + + const char32 begin_char = FirstSpanCodepoint(context_unicode, span); + const char32 paired_begin_char = unilib.GetPairedBracket(begin_char); + if (paired_begin_char != begin_char) { + if (!unilib.IsOpeningBracket(begin_char) || + !IsCodepointInSpan(paired_begin_char, context_unicode, span)) { + ++span.first; + } + } + + if (span.first == span.second) { + return span; + } + + const char32 end_char = LastSpanCodepoint(context_unicode, span); + const char32 paired_end_char = unilib.GetPairedBracket(end_char); + if (paired_end_char != end_char) { + if (!unilib.IsClosingBracket(end_char) || + !IsCodepointInSpan(paired_end_char, context_unicode, span)) { + --span.second; + } + } + + // Should not happen, but let's make sure. 
+ if (span.first > span.second) { + TC_LOG(WARNING) << "Inverse indices result: " << span.first << ", " + << span.second; + span.second = span.first; + } + + return span; +} + +} // namespace libtextclassifier2 diff --git a/strip-unpaired-brackets.h b/strip-unpaired-brackets.h new file mode 100644 index 0000000..4e82c3e --- /dev/null +++ b/strip-unpaired-brackets.h @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_ +#define LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_ + +#include <string> + +#include "types.h" +#include "util/utf8/unilib.h" + +namespace libtextclassifier2 { +// If the first or the last codepoint of the given span is a bracket, the +// bracket is stripped if the span does not contain its corresponding paired +// version. +CodepointSpan StripUnpairedBrackets(const std::string& context, + CodepointSpan span, const UniLib& unilib); + +// Same as above but takes UnicodeText instance directly. 
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, + CodepointSpan span, const UniLib& unilib); + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_ diff --git a/strip-unpaired-brackets_test.cc b/strip-unpaired-brackets_test.cc new file mode 100644 index 0000000..5362500 --- /dev/null +++ b/strip-unpaired-brackets_test.cc @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "strip-unpaired-brackets.h" + +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +TEST(StripUnpairedBracketsTest, StripUnpairedBrackets) { + CREATE_UNILIB_FOR_TESTING + // If the brackets match, nothing gets stripped. + EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib), + std::make_pair(8, 17)); + EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib), + std::make_pair(8, 17)); + + // If the brackets don't match, they get stripped. 
+ EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib), + std::make_pair(9, 16)); + EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib), + std::make_pair(9, 16)); + EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib), + std::make_pair(8, 15)); + EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib), + std::make_pair(8, 15)); + + // Strips brackets correctly from length-1 selections that consist of + // a bracket only. + EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib), + std::make_pair(12, 12)); + EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib), + std::make_pair(12, 12)); + + // Handles invalid spans gracefully. + EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib), + std::make_pair(11, 11)); + EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib), + std::make_pair(0, 0)); + EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib), + std::make_pair(11, 11)); + EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib), + std::make_pair(-1, -1)); +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/common/config.h b/tensor-view.cc index b883e95..4acadc5 100644 --- a/common/config.h +++ b/tensor-view.cc @@ -14,16 +14,18 @@ * limitations under the License. 
*/ -#ifndef LIBTEXTCLASSIFIER_COMMON_CONFIG_H_ -#define LIBTEXTCLASSIFIER_COMMON_CONFIG_H_ +#include "tensor-view.h" -#ifndef PORTABLE_SAFT_MOBILE -#if defined(__ANDROID__) || defined(__APPLE__) -#define PORTABLE_SAFT_MOBILE 1 -#else -#define PORTABLE_SAFT_MOBILE 0 -#endif +namespace libtextclassifier2 { -#endif +namespace internal { +int NumberOfElements(const std::vector<int>& shape) { + int size = 1; + for (const int dim : shape) { + size *= dim; + } + return size; +} +} // namespace internal -#endif // LIBTEXTCLASSIFIER_COMMON_CONFIG_H_ +} // namespace libtextclassifier2 diff --git a/tensor-view.h b/tensor-view.h new file mode 100644 index 0000000..00ab08c --- /dev/null +++ b/tensor-view.h @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_TENSOR_VIEW_H_ +#define LIBTEXTCLASSIFIER_TENSOR_VIEW_H_ + +#include <algorithm> +#include <vector> + +namespace libtextclassifier2 { +namespace internal { +// Computes the number of elements in a tensor of given shape. +int NumberOfElements(const std::vector<int>& shape); +} // namespace internal + +// View of a tensor of given type. +// NOTE: Does not own the underlying memory, so the contract about its validity +// needs to be specified on the interface that returns it. 
+template <typename T> +class TensorView { + public: + TensorView(const T* data, const std::vector<int>& shape) + : data_(data), shape_(shape), size_(internal::NumberOfElements(shape)) {} + + static TensorView Invalid() { + static std::vector<int>& invalid_shape = + *[]() { return new std::vector<int>(0); }(); + return TensorView(nullptr, invalid_shape); + } + + bool is_valid() const { return data_ != nullptr; } + + const std::vector<int>& shape() const { return shape_; } + + int dim(int i) const { return shape_[i]; } + + int dims() const { return shape_.size(); } + + const T* data() const { return data_; } + + int size() const { return size_; } + + bool copy_to(T* dest, int dest_size) const { + if (dest_size < size_) { + return false; + } + std::copy(data_, data_ + size_, dest); + return true; + } + + private: + const T* data_ = nullptr; + const std::vector<int> shape_; + const int size_; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_TENSOR_VIEW_H_ diff --git a/tensor-view_test.cc b/tensor-view_test.cc new file mode 100644 index 0000000..d50fac7 --- /dev/null +++ b/tensor-view_test.cc @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensor-view.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +TEST(TensorViewTest, TestSize) { + std::vector<float> data{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + const TensorView<float> tensor(data.data(), {3, 1, 2}); + EXPECT_TRUE(tensor.is_valid()); + EXPECT_EQ(tensor.shape(), (std::vector<int>{3, 1, 2})); + EXPECT_EQ(tensor.data(), data.data()); + EXPECT_EQ(tensor.size(), 6); + EXPECT_EQ(tensor.dims(), 3); + EXPECT_EQ(tensor.dim(0), 3); + EXPECT_EQ(tensor.dim(1), 1); + EXPECT_EQ(tensor.dim(2), 2); + std::vector<float> output_data(6); + EXPECT_TRUE(tensor.copy_to(output_data.data(), output_data.size())); + EXPECT_EQ(data, output_data); + + // Should not copy when the output is small. + std::vector<float> small_output_data{-1, -1, -1}; + EXPECT_FALSE( + tensor.copy_to(small_output_data.data(), small_output_data.size())); + // The output buffer should not be changed. + EXPECT_EQ(small_output_data, (std::vector<float>{-1, -1, -1})); + + const TensorView<float> invalid_tensor = TensorView<float>::Invalid(); + EXPECT_FALSE(invalid_tensor.is_valid()); +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/test_data/test_model.fb b/test_data/test_model.fb Binary files differnew file mode 100644 index 0000000..c651bdb --- /dev/null +++ b/test_data/test_model.fb diff --git a/test_data/test_model_cc.fb b/test_data/test_model_cc.fb Binary files differnew file mode 100644 index 0000000..53af6bf --- /dev/null +++ b/test_data/test_model_cc.fb diff --git a/test_data/wrong_embeddings.fb b/test_data/wrong_embeddings.fb Binary files differnew file mode 100644 index 0000000..e1aa3ea --- /dev/null +++ b/test_data/wrong_embeddings.fb diff --git a/tests/testdata/langid.model b/tests/testdata/langid.model Binary files differdeleted file mode 100644 index 6b68223..0000000 --- a/tests/testdata/langid.model +++ /dev/null diff --git a/tests/testdata/smartselection.model b/tests/testdata/smartselection.model 
Binary files differdeleted file mode 100644 index 645303d..0000000 --- a/tests/testdata/smartselection.model +++ /dev/null diff --git a/text-classifier.cc b/text-classifier.cc new file mode 100644 index 0000000..e20813a --- /dev/null +++ b/text-classifier.cc @@ -0,0 +1,1576 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "text-classifier.h" + +#include <algorithm> +#include <cctype> +#include <cmath> +#include <iterator> +#include <numeric> + +#include "util/base/logging.h" +#include "util/math/softmax.h" +#include "util/utf8/unicodetext.h" + +namespace libtextclassifier2 { +const std::string& TextClassifier::kOtherCollection = + *[]() { return new std::string("other"); }(); +const std::string& TextClassifier::kPhoneCollection = + *[]() { return new std::string("phone"); }(); +const std::string& TextClassifier::kAddressCollection = + *[]() { return new std::string("address"); }(); +const std::string& TextClassifier::kDateCollection = + *[]() { return new std::string("date"); }(); + +namespace { +const Model* LoadAndVerifyModel(const void* addr, int size) { + const Model* model = GetModel(addr); + + flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size); + if (model->Verify(verifier)) { + return model; + } else { + return nullptr; + } +} +} // namespace + +tflite::Interpreter* InterpreterManager::SelectionInterpreter() { + if (!selection_interpreter_) { + 
TC_CHECK(selection_executor_); + selection_interpreter_ = selection_executor_->CreateInterpreter(); + if (!selection_interpreter_) { + TC_LOG(ERROR) << "Could not build TFLite interpreter."; + } + } + return selection_interpreter_.get(); +} + +tflite::Interpreter* InterpreterManager::ClassificationInterpreter() { + if (!classification_interpreter_) { + TC_CHECK(classification_executor_); + classification_interpreter_ = classification_executor_->CreateInterpreter(); + if (!classification_interpreter_) { + TC_LOG(ERROR) << "Could not build TFLite interpreter."; + } + } + return classification_interpreter_.get(); +} + +std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer( + const char* buffer, int size, const UniLib* unilib) { + const Model* model = LoadAndVerifyModel(buffer, size); + if (model == nullptr) { + return nullptr; + } + + auto classifier = + std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib)); + if (!classifier->IsInitialized()) { + return nullptr; + } + + return classifier; +} + +std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap( + std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) { + if (!(*mmap)->handle().ok()) { + TC_VLOG(1) << "Mmap failed."; + return nullptr; + } + + const Model* model = LoadAndVerifyModel((*mmap)->handle().start(), + (*mmap)->handle().num_bytes()); + if (!model) { + TC_LOG(ERROR) << "Model verification failed."; + return nullptr; + } + + auto classifier = + std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib)); + if (!classifier->IsInitialized()) { + return nullptr; + } + + return classifier; +} + +std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( + int fd, int offset, int size, const UniLib* unilib) { + std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size)); + return FromScopedMmap(&mmap, unilib); +} + +std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( + int fd, const UniLib* unilib) { + std::unique_ptr<ScopedMmap> 
mmap(new ScopedMmap(fd)); + return FromScopedMmap(&mmap, unilib); +} + +std::unique_ptr<TextClassifier> TextClassifier::FromPath( + const std::string& path, const UniLib* unilib) { + std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path)); + return FromScopedMmap(&mmap, unilib); +} + +void TextClassifier::ValidateAndInitialize() { + initialized_ = false; + + if (model_ == nullptr) { + TC_LOG(ERROR) << "No model specified."; + return; + } + + const bool model_enabled_for_annotation = + (model_->triggering_options() != nullptr && + (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)); + const bool model_enabled_for_classification = + (model_->triggering_options() != nullptr && + (model_->triggering_options()->enabled_modes() & + ModeFlag_CLASSIFICATION)); + const bool model_enabled_for_selection = + (model_->triggering_options() != nullptr && + (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)); + + // Annotation requires the selection model. + if (model_enabled_for_annotation || model_enabled_for_selection) { + if (!model_->selection_options()) { + TC_LOG(ERROR) << "No selection options."; + return; + } + if (!model_->selection_feature_options()) { + TC_LOG(ERROR) << "No selection feature options."; + return; + } + if (!model_->selection_feature_options()->bounds_sensitive_features()) { + TC_LOG(ERROR) << "No selection bounds sensitive feature options."; + return; + } + if (!model_->selection_model()) { + TC_LOG(ERROR) << "No selection model."; + return; + } + selection_executor_ = ModelExecutor::Instance(model_->selection_model()); + if (!selection_executor_) { + TC_LOG(ERROR) << "Could not initialize selection executor."; + return; + } + selection_feature_processor_.reset( + new FeatureProcessor(model_->selection_feature_options(), unilib_)); + } + + // Annotation requires the classification model for conflict resolution and + // scoring. + // Selection requires the classification model for conflict resolution. 
+ if (model_enabled_for_annotation || model_enabled_for_classification || + model_enabled_for_selection) { + if (!model_->classification_options()) { + TC_LOG(ERROR) << "No classification options."; + return; + } + + if (!model_->classification_feature_options()) { + TC_LOG(ERROR) << "No classification feature options."; + return; + } + + if (!model_->classification_feature_options() + ->bounds_sensitive_features()) { + TC_LOG(ERROR) << "No classification bounds sensitive feature options."; + return; + } + if (!model_->classification_model()) { + TC_LOG(ERROR) << "No clf model."; + return; + } + + classification_executor_ = + ModelExecutor::Instance(model_->classification_model()); + if (!classification_executor_) { + TC_LOG(ERROR) << "Could not initialize classification executor."; + return; + } + + classification_feature_processor_.reset(new FeatureProcessor( + model_->classification_feature_options(), unilib_)); + } + + // The embeddings need to be specified if the model is to be used for + // classification or selection. + if (model_enabled_for_annotation || model_enabled_for_classification || + model_enabled_for_selection) { + if (!model_->embedding_model()) { + TC_LOG(ERROR) << "No embedding model."; + return; + } + + // Check that the embedding size of the selection and classification model + // matches, as they are using the same embeddings. 
+ if (model_enabled_for_selection && + (model_->selection_feature_options()->embedding_size() != + model_->classification_feature_options()->embedding_size() || + model_->selection_feature_options()->embedding_quantization_bits() != + model_->classification_feature_options() + ->embedding_quantization_bits())) { + TC_LOG(ERROR) << "Mismatching embedding size/quantization."; + return; + } + + embedding_executor_ = TFLiteEmbeddingExecutor::Instance( + model_->embedding_model(), + model_->classification_feature_options()->embedding_size(), + model_->classification_feature_options() + ->embedding_quantization_bits()); + if (!embedding_executor_) { + TC_LOG(ERROR) << "Could not initialize embedding executor."; + return; + } + } + + std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(); + if (model_->regex_model()) { + if (!InitializeRegexModel(decompressor.get())) { + TC_LOG(ERROR) << "Could not initialize regex model."; + return; + } + } + + if (model_->datetime_model()) { + datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(), + *unilib_, decompressor.get()); + if (!datetime_parser_) { + TC_LOG(ERROR) << "Could not initialize datetime parser."; + return; + } + } + + if (model_->output_options()) { + if (model_->output_options()->filtered_collections_annotation()) { + for (const auto collection : + *model_->output_options()->filtered_collections_annotation()) { + filtered_collections_annotation_.insert(collection->str()); + } + } + if (model_->output_options()->filtered_collections_classification()) { + for (const auto collection : + *model_->output_options()->filtered_collections_classification()) { + filtered_collections_classification_.insert(collection->str()); + } + } + if (model_->output_options()->filtered_collections_selection()) { + for (const auto collection : + *model_->output_options()->filtered_collections_selection()) { + filtered_collections_selection_.insert(collection->str()); + } + } + } + + initialized_ = true; 
+} + +bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) { + if (!model_->regex_model()->patterns()) { + return true; + } + + // Initialize pattern recognizers. + int regex_pattern_id = 0; + for (const auto& regex_pattern : *model_->regex_model()->patterns()) { + std::unique_ptr<UniLib::RegexPattern> compiled_pattern = + UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(), + regex_pattern->compressed_pattern(), + decompressor); + if (!compiled_pattern) { + TC_LOG(INFO) << "Failed to load regex pattern"; + return false; + } + + if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) { + annotation_regex_patterns_.push_back(regex_pattern_id); + } + if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) { + classification_regex_patterns_.push_back(regex_pattern_id); + } + if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) { + selection_regex_patterns_.push_back(regex_pattern_id); + } + regex_patterns_.push_back({regex_pattern->collection_name()->str(), + regex_pattern->target_classification_score(), + regex_pattern->priority_score(), + std::move(compiled_pattern)}); + if (regex_pattern->use_approximate_matching()) { + regex_approximate_match_pattern_ids_.insert(regex_pattern_id); + } + ++regex_pattern_id; + } + + return true; +} + +namespace { + +int CountDigits(const std::string& str, CodepointSpan selection_indices) { + int count = 0; + int i = 0; + const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false); + for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) { + if (i >= selection_indices.first && i < selection_indices.second && + isdigit(*it)) { + ++count; + } + } + return count; +} + +std::string ExtractSelection(const std::string& context, + CodepointSpan selection_indices) { + const UnicodeText context_unicode = + UTF8ToUnicodeText(context, /*do_copy=*/false); + auto selection_begin = context_unicode.begin(); + std::advance(selection_begin, selection_indices.first); + auto 
selection_end = context_unicode.begin(); + std::advance(selection_end, selection_indices.second); + return UnicodeText::UTF8Substring(selection_begin, selection_end); +} +} // namespace + +namespace internal { +// Helper function, which if the initial 'span' contains only white-spaces, +// moves the selection to a single-codepoint selection on a left or right side +// of this space. +CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, + const UnicodeText& context_unicode, + const UniLib& unilib) { + TC_CHECK(ValidNonEmptySpan(span)); + + UnicodeText::const_iterator it; + + // Check that the current selection is all whitespaces. + it = context_unicode.begin(); + std::advance(it, span.first); + for (int i = 0; i < (span.second - span.first); ++i, ++it) { + if (!unilib.IsWhitespace(*it)) { + return span; + } + } + + CodepointSpan result; + + // Try moving left. + result = span; + it = context_unicode.begin(); + std::advance(it, span.first); + while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) { + --result.first; + --it; + } + result.second = result.first + 1; + if (!unilib.IsWhitespace(*it)) { + return result; + } + + // If moving left didn't find a non-whitespace character, just return the + // original span. 
+ return span; +} +} // namespace internal + +bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const { + return !span.classification.empty() && + filtered_collections_annotation_.find( + span.classification[0].collection) != + filtered_collections_annotation_.end(); +} + +bool TextClassifier::FilteredForClassification( + const ClassificationResult& classification) const { + return filtered_collections_classification_.find(classification.collection) != + filtered_collections_classification_.end(); +} + +bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const { + return !span.classification.empty() && + filtered_collections_selection_.find( + span.classification[0].collection) != + filtered_collections_selection_.end(); +} + +CodepointSpan TextClassifier::SuggestSelection( + const std::string& context, CodepointSpan click_indices, + const SelectionOptions& options) const { + CodepointSpan original_click_indices = click_indices; + if (!initialized_) { + TC_LOG(ERROR) << "Not initialized"; + return original_click_indices; + } + if (!(model_->enabled_modes() & ModeFlag_SELECTION)) { + return original_click_indices; + } + + const UnicodeText context_unicode = UTF8ToUnicodeText(context, + /*do_copy=*/false); + + if (!context_unicode.is_valid()) { + return original_click_indices; + } + + const int context_codepoint_size = context_unicode.size_codepoints(); + + if (click_indices.first < 0 || click_indices.second < 0 || + click_indices.first >= context_codepoint_size || + click_indices.second > context_codepoint_size || + click_indices.first >= click_indices.second) { + TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: " + << click_indices.first << " " << click_indices.second; + return original_click_indices; + } + + if (model_->snap_whitespace_selections()) { + // We want to expand a purely white-space selection to a multi-selection it + // would've been part of. 
But with this feature disabled we would do a no- + // op, because no token is found. Therefore, we need to modify the + // 'click_indices' a bit to include a part of the token, so that the click- + // finding logic finds the clicked token correctly. This modification is + // done by the following function. Note, that it's enough to check the left + // side of the current selection, because if the white-space is a part of a + // multi-selection, neccessarily both tokens - on the left and the right + // sides need to be selected. Thus snapping only to the left is sufficient + // (there's a check at the bottom that makes sure that if we snap to the + // left token but the result does not contain the initial white-space, + // returns the original indices). + click_indices = internal::SnapLeftIfWhitespaceSelection( + click_indices, context_unicode, *unilib_); + } + + std::vector<AnnotatedSpan> candidates; + InterpreterManager interpreter_manager(selection_executor_.get(), + classification_executor_.get()); + std::vector<Token> tokens; + if (!ModelSuggestSelection(context_unicode, click_indices, + &interpreter_manager, &tokens, &candidates)) { + TC_LOG(ERROR) << "Model suggest selection failed."; + return original_click_indices; + } + if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) { + TC_LOG(ERROR) << "Regex suggest selection failed."; + return original_click_indices; + } + if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), + /*reference_time_ms_utc=*/0, /*reference_timezone=*/"", + options.locales, ModeFlag_SELECTION, &candidates)) { + TC_LOG(ERROR) << "Datetime suggest selection failed."; + return original_click_indices; + } + + // Sort candidates according to their position in the input, so that the next + // code can assume that any connected component of overlapping spans forms a + // contiguous block. 
+ std::sort(candidates.begin(), candidates.end(), + [](const AnnotatedSpan& a, const AnnotatedSpan& b) { + return a.span.first < b.span.first; + }); + + std::vector<int> candidate_indices; + if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, + &candidate_indices)) { + TC_LOG(ERROR) << "Couldn't resolve conflicts."; + return original_click_indices; + } + + for (const int i : candidate_indices) { + if (SpansOverlap(candidates[i].span, click_indices) && + SpansOverlap(candidates[i].span, original_click_indices)) { + // Run model classification if not present but requested and there's a + // classification collection filter specified. + if (candidates[i].classification.empty() && + model_->selection_options()->always_classify_suggested_selection() && + !filtered_collections_selection_.empty()) { + if (!ModelClassifyText( + context, candidates[i].span, &interpreter_manager, + /*embedding_cache=*/nullptr, &candidates[i].classification)) { + return original_click_indices; + } + } + + // Ignore if span classification is filtered. + if (FilteredForSelection(candidates[i])) { + return original_click_indices; + } + + return candidates[i].span; + } + } + + return original_click_indices; +} + +namespace { +// Helper function that returns the index of the first candidate that +// transitively does not overlap with the candidate on 'start_index'. If the end +// of 'candidates' is reached, it returns the index that points right behind the +// array. +int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates, + int start_index) { + int first_non_overlapping = start_index + 1; + CodepointSpan conflicting_span = candidates[start_index].span; + while ( + first_non_overlapping < candidates.size() && + SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) { + // Grow the span to include the current one. 
+ conflicting_span.second = std::max( + conflicting_span.second, candidates[first_non_overlapping].span.second); + + ++first_non_overlapping; + } + return first_non_overlapping; +} +} // namespace + +bool TextClassifier::ResolveConflicts( + const std::vector<AnnotatedSpan>& candidates, const std::string& context, + const std::vector<Token>& cached_tokens, + InterpreterManager* interpreter_manager, std::vector<int>* result) const { + result->clear(); + result->reserve(candidates.size()); + for (int i = 0; i < candidates.size();) { + int first_non_overlapping = + FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i); + + const bool conflict_found = first_non_overlapping != (i + 1); + if (conflict_found) { + std::vector<int> candidate_indices; + if (!ResolveConflict(context, cached_tokens, candidates, i, + first_non_overlapping, interpreter_manager, + &candidate_indices)) { + return false; + } + result->insert(result->end(), candidate_indices.begin(), + candidate_indices.end()); + } else { + result->push_back(i); + } + + // Skip over the whole conflicting group/go to next candidate. 
+ i = first_non_overlapping; + } + return true; +} + +namespace { +inline bool ClassifiedAsOther( + const std::vector<ClassificationResult>& classification) { + return !classification.empty() && + classification[0].collection == TextClassifier::kOtherCollection; +} + +float GetPriorityScore( + const std::vector<ClassificationResult>& classification) { + if (!ClassifiedAsOther(classification)) { + return classification[0].priority_score; + } else { + return -1.0; + } +} +} // namespace + +bool TextClassifier::ResolveConflict( + const std::string& context, const std::vector<Token>& cached_tokens, + const std::vector<AnnotatedSpan>& candidates, int start_index, + int end_index, InterpreterManager* interpreter_manager, + std::vector<int>* chosen_indices) const { + std::vector<int> conflicting_indices; + std::unordered_map<int, float> scores; + for (int i = start_index; i < end_index; ++i) { + conflicting_indices.push_back(i); + if (!candidates[i].classification.empty()) { + scores[i] = GetPriorityScore(candidates[i].classification); + continue; + } + + // OPTIMIZATION: So that we don't have to classify all the ML model + // spans apriori, we wait until we get here, when they conflict with + // something and we need the actual classification scores. So if the + // candidate conflicts and comes from the model, we need to run a + // classification to determine its priority: + std::vector<ClassificationResult> classification; + if (!ModelClassifyText(context, cached_tokens, candidates[i].span, + interpreter_manager, + /*embedding_cache=*/nullptr, &classification)) { + return false; + } + + if (!classification.empty()) { + scores[i] = GetPriorityScore(classification); + } + } + + std::sort(conflicting_indices.begin(), conflicting_indices.end(), + [&scores](int i, int j) { return scores[i] > scores[j]; }); + + // Keeps the candidates sorted by their position in the text (their left span + // index) for fast retrieval down. 
+ std::set<int, std::function<bool(int, int)>> chosen_indices_set( + [&candidates](int a, int b) { + return candidates[a].span.first < candidates[b].span.first; + }); + + // Greedily place the candidates if they don't conflict with the already + // placed ones. + for (int i = 0; i < conflicting_indices.size(); ++i) { + const int considered_candidate = conflicting_indices[i]; + if (!DoesCandidateConflict(considered_candidate, candidates, + chosen_indices_set)) { + chosen_indices_set.insert(considered_candidate); + } + } + + *chosen_indices = + std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end()); + + return true; +} + +bool TextClassifier::ModelSuggestSelection( + const UnicodeText& context_unicode, CodepointSpan click_indices, + InterpreterManager* interpreter_manager, std::vector<Token>* tokens, + std::vector<AnnotatedSpan>* result) const { + if (model_->triggering_options() == nullptr || + !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) { + return true; + } + + int click_pos; + *tokens = selection_feature_processor_->Tokenize(context_unicode); + selection_feature_processor_->RetokenizeAndFindClick( + context_unicode, click_indices, + selection_feature_processor_->GetOptions()->only_use_line_with_click(), + tokens, &click_pos); + if (click_pos == kInvalidIndex) { + TC_VLOG(1) << "Could not calculate the click position."; + return false; + } + + const int symmetry_context_size = + model_->selection_options()->symmetry_context_size(); + const FeatureProcessorOptions_::BoundsSensitiveFeatures* + bounds_sensitive_features = selection_feature_processor_->GetOptions() + ->bounds_sensitive_features(); + + // The symmetry context span is the clicked token with symmetry_context_size + // tokens on either side. 
+ const TokenSpan symmetry_context_span = IntersectTokenSpans( + ExpandTokenSpan(SingleTokenSpan(click_pos), + /*num_tokens_left=*/symmetry_context_size, + /*num_tokens_right=*/symmetry_context_size), + {0, tokens->size()}); + + // Compute the extraction span based on the model type. + TokenSpan extraction_span; + if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { + // The extraction span is the symmetry context span expanded to include + // max_selection_span tokens on either side, which is how far a selection + // can stretch from the click, plus a relevant number of tokens outside of + // the bounds of the selection. + const int max_selection_span = + selection_feature_processor_->GetOptions()->max_selection_span(); + extraction_span = + ExpandTokenSpan(symmetry_context_span, + /*num_tokens_left=*/max_selection_span + + bounds_sensitive_features->num_tokens_before(), + /*num_tokens_right=*/max_selection_span + + bounds_sensitive_features->num_tokens_after()); + } else { + // The extraction span is the symmetry context span expanded to include + // context_size tokens on either side. 
+ const int context_size = + selection_feature_processor_->GetOptions()->context_size(); + extraction_span = ExpandTokenSpan(symmetry_context_span, + /*num_tokens_left=*/context_size, + /*num_tokens_right=*/context_size); + } + extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()}); + + if (!selection_feature_processor_->HasEnoughSupportedCodepoints( + *tokens, extraction_span)) { + return true; + } + + std::unique_ptr<CachedFeatures> cached_features; + if (!selection_feature_processor_->ExtractFeatures( + *tokens, extraction_span, + /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, + embedding_executor_.get(), + /*embedding_cache=*/nullptr, + selection_feature_processor_->EmbeddingSize() + + selection_feature_processor_->DenseFeaturesCount(), + &cached_features)) { + TC_LOG(ERROR) << "Could not extract features."; + return false; + } + + // Produce selection model candidates. + std::vector<TokenSpan> chunks; + if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span, + interpreter_manager->SelectionInterpreter(), *cached_features, + &chunks)) { + TC_LOG(ERROR) << "Could not chunk."; + return false; + } + + for (const TokenSpan& chunk : chunks) { + AnnotatedSpan candidate; + candidate.span = selection_feature_processor_->StripBoundaryCodepoints( + context_unicode, TokenSpanToCodepointSpan(*tokens, chunk)); + if (model_->selection_options()->strip_unpaired_brackets()) { + candidate.span = + StripUnpairedBrackets(context_unicode, candidate.span, *unilib_); + } + + // Only output non-empty spans. 
+ if (candidate.span.first != candidate.span.second) { + result->push_back(candidate); + } + } + return true; +} + +bool TextClassifier::ModelClassifyText( + const std::string& context, CodepointSpan selection_indices, + InterpreterManager* interpreter_manager, + FeatureProcessor::EmbeddingCache* embedding_cache, + std::vector<ClassificationResult>* classification_results) const { + if (model_->triggering_options() == nullptr || + !(model_->triggering_options()->enabled_modes() & + ModeFlag_CLASSIFICATION)) { + return true; + } + return ModelClassifyText(context, {}, selection_indices, interpreter_manager, + embedding_cache, classification_results); +} + +namespace internal { +std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, + CodepointSpan selection_indices, + TokenSpan tokens_around_selection_to_copy) { + const auto first_selection_token = std::upper_bound( + cached_tokens.begin(), cached_tokens.end(), selection_indices.first, + [](int selection_start, const Token& token) { + return selection_start < token.end; + }); + const auto last_selection_token = std::lower_bound( + cached_tokens.begin(), cached_tokens.end(), selection_indices.second, + [](const Token& token, int selection_end) { + return token.start < selection_end; + }); + + const int64 first_token = std::max( + static_cast<int64>(0), + static_cast<int64>((first_selection_token - cached_tokens.begin()) - + tokens_around_selection_to_copy.first)); + const int64 last_token = std::min( + static_cast<int64>(cached_tokens.size()), + static_cast<int64>((last_selection_token - cached_tokens.begin()) + + tokens_around_selection_to_copy.second)); + + std::vector<Token> tokens; + tokens.reserve(last_token - first_token); + for (int i = first_token; i < last_token; ++i) { + tokens.push_back(cached_tokens[i]); + } + return tokens; +} +} // namespace internal + +TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const { + const FeatureProcessorOptions_::BoundsSensitiveFeatures* 
+ bounds_sensitive_features = + classification_feature_processor_->GetOptions() + ->bounds_sensitive_features(); + if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { + // The extraction span is the selection span expanded to include a relevant + // number of tokens outside of the bounds of the selection. + return {bounds_sensitive_features->num_tokens_before(), + bounds_sensitive_features->num_tokens_after()}; + } else { + // The extraction span is the clicked token with context_size tokens on + // either side. + const int context_size = + selection_feature_processor_->GetOptions()->context_size(); + return {context_size, context_size}; + } +} + +bool TextClassifier::ModelClassifyText( + const std::string& context, const std::vector<Token>& cached_tokens, + CodepointSpan selection_indices, InterpreterManager* interpreter_manager, + FeatureProcessor::EmbeddingCache* embedding_cache, + std::vector<ClassificationResult>* classification_results) const { + std::vector<Token> tokens; + if (cached_tokens.empty()) { + tokens = classification_feature_processor_->Tokenize(context); + } else { + tokens = internal::CopyCachedTokens(cached_tokens, selection_indices, + ClassifyTextUpperBoundNeededTokens()); + } + + int click_pos; + classification_feature_processor_->RetokenizeAndFindClick( + context, selection_indices, + classification_feature_processor_->GetOptions() + ->only_use_line_with_click(), + &tokens, &click_pos); + const TokenSpan selection_token_span = + CodepointSpanToTokenSpan(tokens, selection_indices); + const int selection_num_tokens = TokenSpanSize(selection_token_span); + if (model_->classification_options()->max_num_tokens() > 0 && + model_->classification_options()->max_num_tokens() < + selection_num_tokens) { + *classification_results = {{kOtherCollection, 1.0}}; + return true; + } + + const FeatureProcessorOptions_::BoundsSensitiveFeatures* + bounds_sensitive_features = + classification_feature_processor_->GetOptions() + 
->bounds_sensitive_features(); + if (selection_token_span.first == kInvalidIndex || + selection_token_span.second == kInvalidIndex) { + TC_LOG(ERROR) << "Could not determine span."; + return false; + } + + // Compute the extraction span based on the model type. + TokenSpan extraction_span; + if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { + // The extraction span is the selection span expanded to include a relevant + // number of tokens outside of the bounds of the selection. + extraction_span = ExpandTokenSpan( + selection_token_span, + /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(), + /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after()); + } else { + if (click_pos == kInvalidIndex) { + TC_LOG(ERROR) << "Couldn't choose a click position."; + return false; + } + // The extraction span is the clicked token with context_size tokens on + // either side. + const int context_size = + classification_feature_processor_->GetOptions()->context_size(); + extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos), + /*num_tokens_left=*/context_size, + /*num_tokens_right=*/context_size); + } + extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()}); + + if (!classification_feature_processor_->HasEnoughSupportedCodepoints( + tokens, extraction_span)) { + *classification_results = {{kOtherCollection, 1.0}}; + return true; + } + + std::unique_ptr<CachedFeatures> cached_features; + if (!classification_feature_processor_->ExtractFeatures( + tokens, extraction_span, selection_indices, embedding_executor_.get(), + embedding_cache, + classification_feature_processor_->EmbeddingSize() + + classification_feature_processor_->DenseFeaturesCount(), + &cached_features)) { + TC_LOG(ERROR) << "Could not extract features."; + return false; + } + + std::vector<float> features; + features.reserve(cached_features->OutputFeaturesSize()); + if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { + 
cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span, + &features); + } else { + cached_features->AppendClickContextFeaturesForClick(click_pos, &features); + } + + TensorView<float> logits = classification_executor_->ComputeLogits( + TensorView<float>(features.data(), + {1, static_cast<int>(features.size())}), + interpreter_manager->ClassificationInterpreter()); + if (!logits.is_valid()) { + TC_LOG(ERROR) << "Couldn't compute logits."; + return false; + } + + if (logits.dims() != 2 || logits.dim(0) != 1 || + logits.dim(1) != classification_feature_processor_->NumCollections()) { + TC_LOG(ERROR) << "Mismatching output"; + return false; + } + + const std::vector<float> scores = + ComputeSoftmax(logits.data(), logits.dim(1)); + + classification_results->resize(scores.size()); + for (int i = 0; i < scores.size(); i++) { + (*classification_results)[i] = { + classification_feature_processor_->LabelToCollection(i), scores[i]}; + } + std::sort(classification_results->begin(), classification_results->end(), + [](const ClassificationResult& a, const ClassificationResult& b) { + return a.score > b.score; + }); + + // Phone class sanity check. + if (!classification_results->empty() && + classification_results->begin()->collection == kPhoneCollection) { + const int digit_count = CountDigits(context, selection_indices); + if (digit_count < + model_->classification_options()->phone_min_num_digits() || + digit_count > + model_->classification_options()->phone_max_num_digits()) { + *classification_results = {{kOtherCollection, 1.0}}; + } + } + + // Address class sanity check. 
+ if (!classification_results->empty() && + classification_results->begin()->collection == kAddressCollection) { + if (selection_num_tokens < + model_->classification_options()->address_min_num_tokens()) { + *classification_results = {{kOtherCollection, 1.0}}; + } + } + + return true; +} + +bool TextClassifier::RegexClassifyText( + const std::string& context, CodepointSpan selection_indices, + ClassificationResult* classification_result) const { + const std::string selection_text = + ExtractSelection(context, selection_indices); + const UnicodeText selection_text_unicode( + UTF8ToUnicodeText(selection_text, /*do_copy=*/false)); + + // Check whether any of the regular expressions match. + for (const int pattern_id : classification_regex_patterns_) { + const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; + const std::unique_ptr<UniLib::RegexMatcher> matcher = + regex_pattern.pattern->Matcher(selection_text_unicode); + int status = UniLib::RegexMatcher::kNoError; + bool matches; + if (regex_approximate_match_pattern_ids_.find(pattern_id) != + regex_approximate_match_pattern_ids_.end()) { + matches = matcher->ApproximatelyMatches(&status); + } else { + matches = matcher->Matches(&status); + } + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + if (matches) { + *classification_result = {regex_pattern.collection_name, + regex_pattern.target_classification_score, + regex_pattern.priority_score}; + return true; + } + if (status != UniLib::RegexMatcher::kNoError) { + TC_LOG(ERROR) << "Cound't match regex: " << pattern_id; + } + } + + return false; +} + +bool TextClassifier::DatetimeClassifyText( + const std::string& context, CodepointSpan selection_indices, + const ClassificationOptions& options, + ClassificationResult* classification_result) const { + if (!datetime_parser_) { + return false; + } + + const std::string selection_text = + ExtractSelection(context, selection_indices); + + std::vector<DatetimeParseResultSpan> 
datetime_spans; + if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc, + options.reference_timezone, options.locales, + ModeFlag_CLASSIFICATION, + /*anchor_start_end=*/true, &datetime_spans)) { + TC_LOG(ERROR) << "Error during parsing datetime."; + return false; + } + for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { + // Only consider the result valid if the selection and extracted datetime + // spans exactly match. + if (std::make_pair(datetime_span.span.first + selection_indices.first, + datetime_span.span.second + selection_indices.first) == + selection_indices) { + *classification_result = {kDateCollection, + datetime_span.target_classification_score}; + classification_result->datetime_parse_result = datetime_span.data; + return true; + } + } + return false; +} + +std::vector<ClassificationResult> TextClassifier::ClassifyText( + const std::string& context, CodepointSpan selection_indices, + const ClassificationOptions& options) const { + if (!initialized_) { + TC_LOG(ERROR) << "Not initialized"; + return {}; + } + + if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) { + return {}; + } + + if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) { + return {}; + } + + if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) { + TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: " + << std::get<0>(selection_indices) << " " + << std::get<1>(selection_indices); + return {}; + } + + // Try the regular expression models. + ClassificationResult regex_result; + if (RegexClassifyText(context, selection_indices, ®ex_result)) { + if (!FilteredForClassification(regex_result)) { + return {regex_result}; + } else { + return {{kOtherCollection, 1.0}}; + } + } + + // Try the date model. 
+ ClassificationResult datetime_result; + if (DatetimeClassifyText(context, selection_indices, options, + &datetime_result)) { + if (!FilteredForClassification(datetime_result)) { + return {datetime_result}; + } else { + return {{kOtherCollection, 1.0}}; + } + } + + // Fallback to the model. + std::vector<ClassificationResult> model_result; + + InterpreterManager interpreter_manager(selection_executor_.get(), + classification_executor_.get()); + if (ModelClassifyText(context, selection_indices, &interpreter_manager, + /*embedding_cache=*/nullptr, &model_result) && + !model_result.empty()) { + if (!FilteredForClassification(model_result[0])) { + return model_result; + } else { + return {{kOtherCollection, 1.0}}; + } + } + + // No classifications. + return {}; +} + +bool TextClassifier::ModelAnnotate(const std::string& context, + InterpreterManager* interpreter_manager, + std::vector<Token>* tokens, + std::vector<AnnotatedSpan>* result) const { + if (model_->triggering_options() == nullptr || + !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) { + return true; + } + + const UnicodeText context_unicode = UTF8ToUnicodeText(context, + /*do_copy=*/false); + std::vector<UnicodeTextRange> lines; + if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) { + lines.push_back({context_unicode.begin(), context_unicode.end()}); + } else { + lines = selection_feature_processor_->SplitContext(context_unicode); + } + + const float min_annotate_confidence = + (model_->triggering_options() != nullptr + ? 
model_->triggering_options()->min_annotate_confidence() + : 0.f); + + FeatureProcessor::EmbeddingCache embedding_cache; + for (const UnicodeTextRange& line : lines) { + const std::string line_str = + UnicodeText::UTF8Substring(line.first, line.second); + + *tokens = selection_feature_processor_->Tokenize(line_str); + selection_feature_processor_->RetokenizeAndFindClick( + line_str, {0, std::distance(line.first, line.second)}, + selection_feature_processor_->GetOptions()->only_use_line_with_click(), + tokens, + /*click_pos=*/nullptr); + const TokenSpan full_line_span = {0, tokens->size()}; + + // TODO(zilka): Add support for greater granularity of this check. + if (!selection_feature_processor_->HasEnoughSupportedCodepoints( + *tokens, full_line_span)) { + continue; + } + + std::unique_ptr<CachedFeatures> cached_features; + if (!selection_feature_processor_->ExtractFeatures( + *tokens, full_line_span, + /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, + embedding_executor_.get(), + /*embedding_cache=*/nullptr, + selection_feature_processor_->EmbeddingSize() + + selection_feature_processor_->DenseFeaturesCount(), + &cached_features)) { + TC_LOG(ERROR) << "Could not extract features."; + return false; + } + + std::vector<TokenSpan> local_chunks; + if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span, + interpreter_manager->SelectionInterpreter(), + *cached_features, &local_chunks)) { + TC_LOG(ERROR) << "Could not chunk."; + return false; + } + + const int offset = std::distance(context_unicode.begin(), line.first); + for (const TokenSpan& chunk : local_chunks) { + const CodepointSpan codepoint_span = + selection_feature_processor_->StripBoundaryCodepoints( + line_str, TokenSpanToCodepointSpan(*tokens, chunk)); + + // Skip empty spans. 
+ if (codepoint_span.first != codepoint_span.second) { + std::vector<ClassificationResult> classification; + if (!ModelClassifyText(line_str, *tokens, codepoint_span, + interpreter_manager, &embedding_cache, + &classification)) { + TC_LOG(ERROR) << "Could not classify text: " + << (codepoint_span.first + offset) << " " + << (codepoint_span.second + offset); + return false; + } + + // Do not include the span if it's classified as "other". + if (!classification.empty() && !ClassifiedAsOther(classification) && + classification[0].score >= min_annotate_confidence) { + AnnotatedSpan result_span; + result_span.span = {codepoint_span.first + offset, + codepoint_span.second + offset}; + result_span.classification = std::move(classification); + result->push_back(std::move(result_span)); + } + } + } + } + return true; +} + +const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests() + const { + return selection_feature_processor_.get(); +} + +const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests() + const { + return classification_feature_processor_.get(); +} + +const DatetimeParser* TextClassifier::DatetimeParserForTests() const { + return datetime_parser_.get(); +} + +std::vector<AnnotatedSpan> TextClassifier::Annotate( + const std::string& context, const AnnotationOptions& options) const { + std::vector<AnnotatedSpan> candidates; + + if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { + return {}; + } + + if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) { + return {}; + } + + InterpreterManager interpreter_manager(selection_executor_.get(), + classification_executor_.get()); + // Annotate with the selection model. + std::vector<Token> tokens; + if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) { + TC_LOG(ERROR) << "Couldn't run ModelAnnotate."; + return {}; + } + + // Annotate with the regular expression models. 
+ if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), + annotation_regex_patterns_, &candidates)) { + TC_LOG(ERROR) << "Couldn't run RegexChunk."; + return {}; + } + + // Annotate with the datetime model. + if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), + options.reference_time_ms_utc, options.reference_timezone, + options.locales, ModeFlag_ANNOTATION, &candidates)) { + TC_LOG(ERROR) << "Couldn't run RegexChunk."; + return {}; + } + + // Sort candidates according to their position in the input, so that the next + // code can assume that any connected component of overlapping spans forms a + // contiguous block. + std::sort(candidates.begin(), candidates.end(), + [](const AnnotatedSpan& a, const AnnotatedSpan& b) { + return a.span.first < b.span.first; + }); + + std::vector<int> candidate_indices; + if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, + &candidate_indices)) { + TC_LOG(ERROR) << "Couldn't resolve conflicts."; + return {}; + } + + std::vector<AnnotatedSpan> result; + result.reserve(candidate_indices.size()); + for (const int i : candidate_indices) { + if (!candidates[i].classification.empty() && + !ClassifiedAsOther(candidates[i].classification) && + !FilteredForAnnotation(candidates[i])) { + result.push_back(std::move(candidates[i])); + } + } + + return result; +} + +bool TextClassifier::RegexChunk(const UnicodeText& context_unicode, + const std::vector<int>& rules, + std::vector<AnnotatedSpan>* result) const { + for (int pattern_id : rules) { + const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; + const auto matcher = regex_pattern.pattern->Matcher(context_unicode); + if (!matcher) { + TC_LOG(ERROR) << "Could not get regex matcher for pattern: " + << pattern_id; + return false; + } + + int status = UniLib::RegexMatcher::kNoError; + while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { + result->emplace_back(); + // Selection/annotation regular expressions 
need to specify a capturing + // group specifying the selection. + result->back().span = {matcher->Start(1, &status), + matcher->End(1, &status)}; + result->back().classification = { + {regex_pattern.collection_name, + regex_pattern.target_classification_score, + regex_pattern.priority_score}}; + } + } + return true; +} + +bool TextClassifier::ModelChunk(int num_tokens, + const TokenSpan& span_of_interest, + tflite::Interpreter* selection_interpreter, + const CachedFeatures& cached_features, + std::vector<TokenSpan>* chunks) const { + const int max_selection_span = + selection_feature_processor_->GetOptions()->max_selection_span(); + // The inference span is the span of interest expanded to include + // max_selection_span tokens on either side, which is how far a selection can + // stretch from the click. + const TokenSpan inference_span = IntersectTokenSpans( + ExpandTokenSpan(span_of_interest, + /*num_tokens_left=*/max_selection_span, + /*num_tokens_right=*/max_selection_span), + {0, num_tokens}); + + std::vector<ScoredChunk> scored_chunks; + if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() && + selection_feature_processor_->GetOptions() + ->bounds_sensitive_features() + ->enabled()) { + if (!ModelBoundsSensitiveScoreChunks( + num_tokens, span_of_interest, inference_span, cached_features, + selection_interpreter, &scored_chunks)) { + return false; + } + } else { + if (!ModelClickContextScoreChunks(num_tokens, span_of_interest, + cached_features, selection_interpreter, + &scored_chunks)) { + return false; + } + } + std::sort(scored_chunks.rbegin(), scored_chunks.rend(), + [](const ScoredChunk& lhs, const ScoredChunk& rhs) { + return lhs.score < rhs.score; + }); + + // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick + // them greedily as long as they do not overlap with any previously picked + // chunks. 
+ std::vector<bool> token_used(TokenSpanSize(inference_span)); + chunks->clear(); + for (const ScoredChunk& scored_chunk : scored_chunks) { + bool feasible = true; + for (int i = scored_chunk.token_span.first; + i < scored_chunk.token_span.second; ++i) { + if (token_used[i - inference_span.first]) { + feasible = false; + break; + } + } + + if (!feasible) { + continue; + } + + for (int i = scored_chunk.token_span.first; + i < scored_chunk.token_span.second; ++i) { + token_used[i - inference_span.first] = true; + } + + chunks->push_back(scored_chunk.token_span); + } + + std::sort(chunks->begin(), chunks->end()); + + return true; +} + +namespace { +// Updates the value at the given key in the map to maximum of the current value +// and the given value, or simply inserts the value if the key is not yet there. +template <typename Map> +void UpdateMax(Map* map, typename Map::key_type key, + typename Map::mapped_type value) { + const auto it = map->find(key); + if (it != map->end()) { + it->second = std::max(it->second, value); + } else { + (*map)[key] = value; + } +} +} // namespace + +bool TextClassifier::ModelClickContextScoreChunks( + int num_tokens, const TokenSpan& span_of_interest, + const CachedFeatures& cached_features, + tflite::Interpreter* selection_interpreter, + std::vector<ScoredChunk>* scored_chunks) const { + const int max_batch_size = model_->selection_options()->batch_size(); + + std::vector<float> all_features; + std::map<TokenSpan, float> chunk_scores; + for (int batch_start = span_of_interest.first; + batch_start < span_of_interest.second; batch_start += max_batch_size) { + const int batch_end = + std::min(batch_start + max_batch_size, span_of_interest.second); + + // Prepare features for the whole batch. 
+ all_features.clear(); + all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize()); + for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) { + cached_features.AppendClickContextFeaturesForClick(click_pos, + &all_features); + } + + // Run batched inference. + const int batch_size = batch_end - batch_start; + const int features_size = cached_features.OutputFeaturesSize(); + TensorView<float> logits = selection_executor_->ComputeLogits( + TensorView<float>(all_features.data(), {batch_size, features_size}), + selection_interpreter); + if (!logits.is_valid()) { + TC_LOG(ERROR) << "Couldn't compute logits."; + return false; + } + if (logits.dims() != 2 || logits.dim(0) != batch_size || + logits.dim(1) != + selection_feature_processor_->GetSelectionLabelCount()) { + TC_LOG(ERROR) << "Mismatching output."; + return false; + } + + // Save results. + for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) { + const std::vector<float> scores = ComputeSoftmax( + logits.data() + logits.dim(1) * (click_pos - batch_start), + logits.dim(1)); + for (int j = 0; + j < selection_feature_processor_->GetSelectionLabelCount(); ++j) { + TokenSpan relative_token_span; + if (!selection_feature_processor_->LabelToTokenSpan( + j, &relative_token_span)) { + TC_LOG(ERROR) << "Couldn't map the label to a token span."; + return false; + } + const TokenSpan candidate_span = ExpandTokenSpan( + SingleTokenSpan(click_pos), relative_token_span.first, + relative_token_span.second); + if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) { + UpdateMax(&chunk_scores, candidate_span, scores[j]); + } + } + } + } + + scored_chunks->clear(); + scored_chunks->reserve(chunk_scores.size()); + for (const auto& entry : chunk_scores) { + scored_chunks->push_back(ScoredChunk{entry.first, entry.second}); + } + + return true; +} + +bool TextClassifier::ModelBoundsSensitiveScoreChunks( + int num_tokens, const TokenSpan& span_of_interest, + const 
TokenSpan& inference_span, const CachedFeatures& cached_features, + tflite::Interpreter* selection_interpreter, + std::vector<ScoredChunk>* scored_chunks) const { + const int max_selection_span = + selection_feature_processor_->GetOptions()->max_selection_span(); + const int max_chunk_length = selection_feature_processor_->GetOptions() + ->selection_reduced_output_space() + ? max_selection_span + 1 + : 2 * max_selection_span + 1; + const bool score_single_token_spans_as_zero = + selection_feature_processor_->GetOptions() + ->bounds_sensitive_features() + ->score_single_token_spans_as_zero(); + + scored_chunks->clear(); + if (score_single_token_spans_as_zero) { + scored_chunks->reserve(TokenSpanSize(span_of_interest)); + } + + // Prepare all chunk candidates into one batch: + // - Are contained in the inference span + // - Have a non-empty intersection with the span of interest + // - Are at least one token long + // - Are not longer than the maximum chunk length + std::vector<TokenSpan> candidate_spans; + for (int start = inference_span.first; start < span_of_interest.second; + ++start) { + const int leftmost_end_index = std::max(start, span_of_interest.first) + 1; + for (int end = leftmost_end_index; + end <= inference_span.second && end - start <= max_chunk_length; + ++end) { + const TokenSpan candidate_span = {start, end}; + if (score_single_token_spans_as_zero && + TokenSpanSize(candidate_span) == 1) { + // Do not include the single token span in the batch, add a zero score + // for it directly to the output. 
+ scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f}); + } else { + candidate_spans.push_back(candidate_span); + } + } + } + + const int max_batch_size = model_->selection_options()->batch_size(); + + std::vector<float> all_features; + scored_chunks->reserve(scored_chunks->size() + candidate_spans.size()); + for (int batch_start = 0; batch_start < candidate_spans.size(); + batch_start += max_batch_size) { + const int batch_end = std::min(batch_start + max_batch_size, + static_cast<int>(candidate_spans.size())); + + // Prepare features for the whole batch. + all_features.clear(); + all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize()); + for (int i = batch_start; i < batch_end; ++i) { + cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i], + &all_features); + } + + // Run batched inference. + const int batch_size = batch_end - batch_start; + const int features_size = cached_features.OutputFeaturesSize(); + TensorView<float> logits = selection_executor_->ComputeLogits( + TensorView<float>(all_features.data(), {batch_size, features_size}), + selection_interpreter); + if (!logits.is_valid()) { + TC_LOG(ERROR) << "Couldn't compute logits."; + return false; + } + if (logits.dims() != 2 || logits.dim(0) != batch_size || + logits.dim(1) != 1) { + TC_LOG(ERROR) << "Mismatching output."; + return false; + } + + // Save results. 
+ for (int i = batch_start; i < batch_end; ++i) { + scored_chunks->push_back( + ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]}); + } + } + + return true; +} + +bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& locales, ModeFlag mode, + std::vector<AnnotatedSpan>* result) const { + if (!datetime_parser_) { + return true; + } + + std::vector<DatetimeParseResultSpan> datetime_spans; + if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc, + reference_timezone, locales, mode, + /*anchor_start_end=*/false, &datetime_spans)) { + return false; + } + for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { + AnnotatedSpan annotated_span; + annotated_span.span = datetime_span.span; + annotated_span.classification = {{kDateCollection, + datetime_span.target_classification_score, + datetime_span.priority_score}}; + annotated_span.classification[0].datetime_parse_result = datetime_span.data; + + result->push_back(std::move(annotated_span)); + } + return true; +} + +const Model* ViewModel(const void* buffer, int size) { + if (!buffer) { + return nullptr; + } + + return LoadAndVerifyModel(buffer, size); +} + +} // namespace libtextclassifier2 diff --git a/text-classifier.h b/text-classifier.h new file mode 100644 index 0000000..0692ecd --- /dev/null +++ b/text-classifier.h @@ -0,0 +1,381 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Inference code for the text classification model. + +#ifndef LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ +#define LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ + +#include <memory> +#include <set> +#include <string> +#include <vector> + +#include "datetime/parser.h" +#include "feature-processor.h" +#include "model-executor.h" +#include "model_generated.h" +#include "strip-unpaired-brackets.h" +#include "types.h" +#include "util/memory/mmap.h" +#include "util/utf8/unilib.h" +#include "zlib-utils.h" + +namespace libtextclassifier2 { + +struct SelectionOptions { + // Comma-separated list of locale specification for the input text (BCP 47 + // tags). + std::string locales; + + static SelectionOptions Default() { return SelectionOptions(); } +}; + +struct ClassificationOptions { + // For parsing relative datetimes, the reference now time against which the + // relative datetimes get resolved. + // UTC milliseconds since epoch. + int64 reference_time_ms_utc = 0; + + // Timezone in which the input text was written (format as accepted by ICU). + std::string reference_timezone; + + // Comma-separated list of locale specification for the input text (BCP 47 + // tags). + std::string locales; + + static ClassificationOptions Default() { return ClassificationOptions(); } +}; + +struct AnnotationOptions { + // For parsing relative datetimes, the reference now time against which the + // relative datetimes get resolved. + // UTC milliseconds since epoch. + int64 reference_time_ms_utc = 0; + + // Timezone in which the input text was written (format as accepted by ICU). + std::string reference_timezone; + + // Comma-separated list of locale specification for the input text (BCP 47 + // tags). + std::string locales; + + static AnnotationOptions Default() { return AnnotationOptions(); } +}; + +// Holds TFLite interpreters for selection and classification models. 
+// NOTE: This class is not thread-safe, thus should NOT be re-used across
+// threads.
+class InterpreterManager {
+ public:
+  // The constructor can be called with nullptr for any of the executors, and is
+  // a defined behavior, as long as the corresponding *Interpreter() method is
+  // not called when the executor is null.
+  InterpreterManager(const ModelExecutor* selection_executor,
+                     const ModelExecutor* classification_executor)
+      : selection_executor_(selection_executor),
+        classification_executor_(classification_executor) {}
+
+  // Gets or creates and caches an interpreter for the selection model.
+  tflite::Interpreter* SelectionInterpreter();
+
+  // Gets or creates and caches an interpreter for the classification model.
+  tflite::Interpreter* ClassificationInterpreter();
+
+ private:
+  const ModelExecutor* selection_executor_;
+  const ModelExecutor* classification_executor_;
+
+  std::unique_ptr<tflite::Interpreter> selection_interpreter_;
+  std::unique_ptr<tflite::Interpreter> classification_interpreter_;
+};
+
+// A text processing model that provides text classification, annotation,
+// selection suggestion for various types.
+// NOTE: This class is not thread-safe.
+class TextClassifier {
+ public:
+  static std::unique_ptr<TextClassifier> FromUnownedBuffer(
+      const char* buffer, int size, const UniLib* unilib = nullptr);
+  // Takes ownership of the mmap.
+  static std::unique_ptr<TextClassifier> FromScopedMmap(
+      std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr);
+  static std::unique_ptr<TextClassifier> FromFileDescriptor(
+      int fd, int offset, int size, const UniLib* unilib = nullptr);
+  static std::unique_ptr<TextClassifier> FromFileDescriptor(
+      int fd, const UniLib* unilib = nullptr);
+  static std::unique_ptr<TextClassifier> FromPath(
+      const std::string& path, const UniLib* unilib = nullptr);
+
+  // Returns true if the model is ready for use.
+ bool IsInitialized() { return initialized_; } + + // Runs inference for given a context and current selection (i.e. index + // of the first and one past last selected characters (utf8 codepoint + // offsets)). Returns the indices (utf8 codepoint offsets) of the selection + // beginning character and one past selection end character. + // Returns the original click_indices if an error occurs. + // NOTE: The selection indices are passed in and returned in terms of + // UTF8 codepoints (not bytes). + // Requires that the model is a smart selection model. + CodepointSpan SuggestSelection( + const std::string& context, CodepointSpan click_indices, + const SelectionOptions& options = SelectionOptions::Default()) const; + + // Classifies the selected text given the context string. + // Returns an empty result if an error occurs. + std::vector<ClassificationResult> ClassifyText( + const std::string& context, CodepointSpan selection_indices, + const ClassificationOptions& options = + ClassificationOptions::Default()) const; + + // Annotates given input text. The annotations are sorted by their position + // in the context string and exclude spans classified as 'other'. + std::vector<AnnotatedSpan> Annotate( + const std::string& context, + const AnnotationOptions& options = AnnotationOptions::Default()) const; + + // Exposes the feature processor for tests and evaluations. + const FeatureProcessor* SelectionFeatureProcessorForTests() const; + const FeatureProcessor* ClassificationFeatureProcessorForTests() const; + + // Exposes the date time parser for tests and evaluations. + const DatetimeParser* DatetimeParserForTests() const; + + // String collection names for various classes. 
+ static const std::string& kOtherCollection; + static const std::string& kPhoneCollection; + static const std::string& kAddressCollection; + static const std::string& kDateCollection; + + protected: + struct ScoredChunk { + TokenSpan token_span; + float score; + }; + + // Constructs and initializes text classifier from given model. + // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'. + TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model, + const UniLib* unilib) + : model_(model), + mmap_(std::move(*mmap)), + owned_unilib_(nullptr), + unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) { + ValidateAndInitialize(); + } + + // Constructs, validates and initializes text classifier from given model. + // Does not own the buffer that backs 'model'. + explicit TextClassifier(const Model* model, const UniLib* unilib) + : model_(model), + owned_unilib_(nullptr), + unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) { + ValidateAndInitialize(); + } + + // Checks that model contains all required fields, and initializes internal + // datastructures. + void ValidateAndInitialize(); + + // Initializes regular expressions for the regex model. + bool InitializeRegexModel(ZlibDecompressor* decompressor); + + // Resolves conflicts in the list of candidates by removing some overlapping + // ones. Returns indices of the surviving ones. + // NOTE: Assumes that the candidates are sorted according to their position in + // the span. + bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates, + const std::string& context, + const std::vector<Token>& cached_tokens, + InterpreterManager* interpreter_manager, + std::vector<int>* result) const; + + // Resolves one conflict between candidates on indices 'start_index' + // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate + // indices to 'chosen_indices'. Returns false if a problem arises. 
+ bool ResolveConflict(const std::string& context, + const std::vector<Token>& cached_tokens, + const std::vector<AnnotatedSpan>& candidates, + int start_index, int end_index, + InterpreterManager* interpreter_manager, + std::vector<int>* chosen_indices) const; + + // Gets selection candidates from the ML model. + // Provides the tokens produced during tokenization of the context string for + // reuse. + bool ModelSuggestSelection(const UnicodeText& context_unicode, + CodepointSpan click_indices, + InterpreterManager* interpreter_manager, + std::vector<Token>* tokens, + std::vector<AnnotatedSpan>* result) const; + + // Classifies the selected text given the context string with the + // classification model. + // Returns true if no error occurred. + bool ModelClassifyText( + const std::string& context, const std::vector<Token>& cached_tokens, + CodepointSpan selection_indices, InterpreterManager* interpreter_manager, + FeatureProcessor::EmbeddingCache* embedding_cache, + std::vector<ClassificationResult>* classification_results) const; + + bool ModelClassifyText( + const std::string& context, CodepointSpan selection_indices, + InterpreterManager* interpreter_manager, + FeatureProcessor::EmbeddingCache* embedding_cache, + std::vector<ClassificationResult>* classification_results) const; + + // Returns a relative token span that represents how many tokens on the left + // from the selection and right from the selection are needed for the + // classifier input. + TokenSpan ClassifyTextUpperBoundNeededTokens() const; + + // Classifies the selected text with the regular expressions models. + // Returns true if any regular expression matched and the result was set. + bool RegexClassifyText(const std::string& context, + CodepointSpan selection_indices, + ClassificationResult* classification_result) const; + + // Classifies the selected text with the date time model. + // Returns true if there was a match and the result was set. 
+ bool DatetimeClassifyText(const std::string& context, + CodepointSpan selection_indices, + const ClassificationOptions& options, + ClassificationResult* classification_result) const; + + // Chunks given input text with the selection model and classifies the spans + // with the classification model. + // The annotations are sorted by their position in the context string and + // exclude spans classified as 'other'. + // Provides the tokens produced during tokenization of the context string for + // reuse. + bool ModelAnnotate(const std::string& context, + InterpreterManager* interpreter_manager, + std::vector<Token>* tokens, + std::vector<AnnotatedSpan>* result) const; + + // Groups the tokens into chunks. A chunk is a token span that should be the + // suggested selection when any of its contained tokens is clicked. The chunks + // are non-overlapping and are sorted by their position in the context string. + // "num_tokens" is the total number of tokens available (as this method does + // not need the actual vector of tokens). + // "span_of_interest" is a span of all the tokens that could be clicked. + // The resulting chunks all have to overlap with it and they cover this span + // completely. The first and last chunk might extend beyond it. + // The chunks vector is cleared before filling. + bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest, + tflite::Interpreter* selection_interpreter, + const CachedFeatures& cached_features, + std::vector<TokenSpan>* chunks) const; + + // A helper method for ModelChunk(). It generates scored chunk candidates for + // a click context model. + // NOTE: The returned chunks can (and most likely do) overlap. + bool ModelClickContextScoreChunks( + int num_tokens, const TokenSpan& span_of_interest, + const CachedFeatures& cached_features, + tflite::Interpreter* selection_interpreter, + std::vector<ScoredChunk>* scored_chunks) const; + + // A helper method for ModelChunk(). 
It generates scored chunk candidates for + // a bounds-sensitive model. + // NOTE: The returned chunks can (and most likely do) overlap. + bool ModelBoundsSensitiveScoreChunks( + int num_tokens, const TokenSpan& span_of_interest, + const TokenSpan& inference_span, const CachedFeatures& cached_features, + tflite::Interpreter* selection_interpreter, + std::vector<ScoredChunk>* scored_chunks) const; + + // Produces chunks isolated by a set of regular expressions. + bool RegexChunk(const UnicodeText& context_unicode, + const std::vector<int>& rules, + std::vector<AnnotatedSpan>* result) const; + + // Produces chunks from the datetime parser. + bool DatetimeChunk(const UnicodeText& context_unicode, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& locales, ModeFlag mode, + std::vector<AnnotatedSpan>* result) const; + + // Returns whether a classification should be filtered. + bool FilteredForAnnotation(const AnnotatedSpan& span) const; + bool FilteredForClassification( + const ClassificationResult& classification) const; + bool FilteredForSelection(const AnnotatedSpan& span) const; + + const Model* model_; + + std::unique_ptr<const ModelExecutor> selection_executor_; + std::unique_ptr<const ModelExecutor> classification_executor_; + std::unique_ptr<const EmbeddingExecutor> embedding_executor_; + + std::unique_ptr<const FeatureProcessor> selection_feature_processor_; + std::unique_ptr<const FeatureProcessor> classification_feature_processor_; + + std::unique_ptr<const DatetimeParser> datetime_parser_; + + private: + struct CompiledRegexPattern { + std::string collection_name; + float target_classification_score; + float priority_score; + std::unique_ptr<UniLib::RegexPattern> pattern; + }; + + std::unique_ptr<ScopedMmap> mmap_; + bool initialized_ = false; + bool enabled_for_annotation_ = false; + bool enabled_for_classification_ = false; + bool enabled_for_selection_ = false; + std::unordered_set<std::string> 
filtered_collections_annotation_; + std::unordered_set<std::string> filtered_collections_classification_; + std::unordered_set<std::string> filtered_collections_selection_; + + std::vector<CompiledRegexPattern> regex_patterns_; + std::unordered_set<int> regex_approximate_match_pattern_ids_; + + // Indices into regex_patterns_ for the different modes. + std::vector<int> annotation_regex_patterns_, classification_regex_patterns_, + selection_regex_patterns_; + + std::unique_ptr<UniLib> owned_unilib_; + const UniLib* unilib_; +}; + +namespace internal { + +// Helper function, which if the initial 'span' contains only white-spaces, +// moves the selection to a single-codepoint selection on the left side +// of this block of white-space. +CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, + const UnicodeText& context_unicode, + const UniLib& unilib); + +// Copies tokens from 'cached_tokens' that are +// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant +// from the tokens that correspond to 'selection_indices'. +std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, + CodepointSpan selection_indices, + TokenSpan tokens_around_selection_to_copy); +} // namespace internal + +// Interprets the buffer as a Model flatbuffer and returns it for reading. +const Model* ViewModel(const void* buffer, int size); + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ diff --git a/text-classifier_test.cc b/text-classifier_test.cc new file mode 100644 index 0000000..c8ced76 --- /dev/null +++ b/text-classifier_test.cc @@ -0,0 +1,1291 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "text-classifier.h" + +#include <fstream> +#include <iostream> +#include <memory> +#include <string> + +#include "model_generated.h" +#include "types-test-util.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +using testing::ElementsAreArray; +using testing::IsEmpty; +using testing::Pair; +using testing::Values; + +std::string FirstResult(const std::vector<ClassificationResult>& results) { + if (results.empty()) { + return "<INVALID RESULTS>"; + } + return results[0].collection; +} + +MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") { + return testing::Value(arg.span, Pair(start, end)) && + testing::Value(FirstResult(arg.classification), best_class); +} + +std::string ReadFile(const std::string& file_name) { + std::ifstream file_stream(file_name); + return std::string(std::istreambuf_iterator<char>(file_stream), {}); +} + +std::string GetModelPath() { + return LIBTEXTCLASSIFIER_TEST_DATA_DIR; +} + +TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + "wrong_embeddings.fb", &unilib); + EXPECT_FALSE(classifier); +} + +class TextClassifierTest : public ::testing::TestWithParam<const char*> {}; + +INSTANTIATE_TEST_CASE_P(ClickContext, TextClassifierTest, + Values("test_model_cc.fb")); +INSTANTIATE_TEST_CASE_P(BoundsSensitive, TextClassifierTest, + Values("test_model.fb")); + +TEST_P(TextClassifierTest, ClassifyText) { + CREATE_UNILIB_FOR_TESTING; + 
std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ("other", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at", {15, 27}))); + EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( + "Call me at (800) 123-456 today", {11, 24}))); + + // More lines. + EXPECT_EQ("other", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at|Visit " + "www.google.com every today!|Call me at (800) 123-456 today.", + {15, 27}))); + EXPECT_EQ("phone", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at|Visit " + "www.google.com every today!|Call me at (800) 123-456 today.", + {90, 103}))); + + // Single word. + EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5}))); + EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4}))); + EXPECT_EQ("<INVALID RESULTS>", + FirstResult(classifier->ClassifyText("asdf", {0, 0}))); + + // Junk. + EXPECT_EQ("<INVALID RESULTS>", + FirstResult(classifier->ClassifyText("", {0, 0}))); + EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText( + "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5}))); + // Test invalid utf8 input. 
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText( + "\xf0\x9f\x98\x8b\x8b", {0, 0}))); +} + +TEST_P(TextClassifierTest, ClassifyTextDisabledFail) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + unpacked_model->classification_model.clear(); + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + + // The classification model is still needed for selection scores. + ASSERT_FALSE(classifier); +} + +TEST_P(TextClassifierTest, ClassifyTextDisabled) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = + ModeFlag_ANNOTATION_AND_SELECTION; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_THAT( + classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}), + IsEmpty()); +} + +TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<TextClassifier> classifier = + 
TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(), + &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( + "Call me at (800) 123-456 today", {11, 24}))); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // Disable phone classification + unpacked_model->output_options->filtered_collections_classification.push_back( + "phone"); + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ("other", FirstResult(classifier->ClassifyText( + "Call me at (800) 123-456 today", {11, 24}))); + + // Check that the address classification still passes. + EXPECT_EQ("address", FirstResult(classifier->ClassifyText( + "350 Third Street, Cambridge", {0, 27}))); +} + +std::unique_ptr<RegexModel_::PatternT> MakePattern( + const std::string& collection_name, const std::string& pattern, + const bool enabled_for_classification, const bool enabled_for_selection, + const bool enabled_for_annotation, const float score) { + std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT); + result->collection_name = collection_name; + result->pattern = pattern; + // We cannot directly operate with |= on the flag, so use an int here. 
+ int enabled_modes = ModeFlag_NONE; + if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION; + if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION; + if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION; + result->enabled_modes = static_cast<ModeFlag>(enabled_modes); + result->target_classification_score = score; + result->priority_score = score; + return result; +} + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, ClassifyTextRegularExpression) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. + unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", "Barack Obama", /*enabled_for_classification=*/true, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5)); + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ("flight", + FirstResult(classifier->ClassifyText( + "Your flight LX373 is delayed by 3 hours.", {12, 17}))); + EXPECT_EQ("person", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at", {15, 27}))); + EXPECT_EQ("email", + FirstResult(classifier->ClassifyText("you@android.com", {0, 15}))); + EXPECT_EQ("email", FirstResult(classifier->ClassifyText( + "Contact me at you@android.com", {14, 29}))); + + EXPECT_EQ("url", FirstResult(classifier->ClassifyText( + "Visit www.google.com every today!", {6, 
20}))); + + EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5}))); + EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd", + {7, 12}))); + + // More lines. + EXPECT_EQ("url", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at|Visit " + "www.google.com every today!|Call me at (800) 123-456 today.", + {51, 65}))); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. + unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.back()->priority_score = 1.1; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + // Check regular expression selection. 
+ EXPECT_EQ(classifier->SuggestSelection( + "Your flight MA 0123 is delayed by 3 hours.", {12, 14}), + std::make_pair(12, 19)); + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon Barack Obama gave a speech at", {15, 21}), + std::make_pair(15, 27)); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, + SuggestSelectionRegularExpressionConflictsModelWins) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. + unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.back()->priority_score = 0.5; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); + ASSERT_TRUE(classifier); + + // Check conflict resolution. + EXPECT_EQ( + classifier->SuggestSelection( + "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123", + {55, 57}), + std::make_pair(26, 62)); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, + SuggestSelectionRegularExpressionConflictsRegexWins) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. 
+ unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.back()->priority_score = 1.1; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); + ASSERT_TRUE(classifier); + + // Check conflict resolution. + EXPECT_EQ( + classifier->SuggestSelection( + "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123", + {55, 57}), + std::make_pair(55, 62)); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, AnnotateRegex) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. 
+ unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5)); + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(6, 18, "person"), + IsAnnotatedSpan(19, 24, "date"), + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +TEST_P(TextClassifierTest, PhoneFiltering) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( + "phone: (123) 456 789", {7, 20}))); + EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( + "phone: (123) 456 789,0001112", {7, 25}))); + EXPECT_EQ("other", FirstResult(classifier->ClassifyText( + "phone: (123) 456 789,0001112", {7, 28}))); +} + +TEST_P(TextClassifierTest, SuggestSelection) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon Barack Obama gave a speech at", {15, 21}), + 
std::make_pair(15, 21)); + + // Try passing whole string. + // If more than 1 token is specified, we should return back what entered. + EXPECT_EQ( + classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}), + std::make_pair(0, 27)); + + // Single letter. + EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1)); + + // Single word. + EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4)); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), + std::make_pair(11, 23)); + + // Unpaired bracket stripping. + EXPECT_EQ( + classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}), + std::make_pair(11, 25)); + EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}), + std::make_pair(12, 15)); + EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}), + std::make_pair(11, 15)); + EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}), + std::make_pair(12, 15)); + + // If the resulting selection would be empty, the original span is returned. + EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}), + std::make_pair(11, 13)); + EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}), + std::make_pair(11, 12)); + EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}), + std::make_pair(11, 12)); +} + +TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Disable the selection model. 
+ unpacked_model->selection_model.clear(); + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + // Selection model needs to be present for annotation. + ASSERT_FALSE(classifier); +} + +TEST_P(TextClassifierTest, SuggestSelectionDisabled) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Disable the selection model. + unpacked_model->selection_model.clear(); + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION; + unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), + std::make_pair(11, 14)); + + EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( + "call me at (800) 123-456 today", {11, 24}))); + + EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"), + IsEmpty()); +} + +TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer(test_model.c_str(), 
test_model.size(), + &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), + std::make_pair(11, 23)); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // Disable phone selection + unpacked_model->output_options->filtered_collections_selection.push_back( + "phone"); + // We need to force this for filtering. + unpacked_model->selection_options->always_classify_suggested_selection = true; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), + std::make_pair(11, 14)); + + // Address selection should still work. + EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}), + std::make_pair(0, 27)); +} + +TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}), + std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}), + std::make_pair(0, 27)); + EXPECT_EQ( + classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}), + std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge", + {16, 22}), + std::make_pair(6, 33)); +} + +TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + 
ASSERT_TRUE(classifier); + + EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}), + std::make_pair(4, 16)); + EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}), + std::make_pair(0, 12)); + + SelectionOptions options; + EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options), + std::make_pair(0, 7)); +} + +TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + // From the right. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon BarackObama, gave a speech at", {15, 26}), + std::make_pair(15, 26)); + + // From the right multiple. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}), + std::make_pair(15, 26)); + + // From the left multiple. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}), + std::make_pair(21, 32)); + + // From both sides. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon !BarackObama,- gave a speech at", {16, 27}), + std::make_pair(16, 27)); +} + +TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + // Try passing in bunch of invalid selections. 
+ EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}), + std::make_pair(-10, 27)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}), + std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}), + std::make_pair(-30, 300)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}), + std::make_pair(-10, -1)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}), + std::make_pair(100, 17)); + + // Try passing invalid utf8. + EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}), + std::make_pair(-1, -1)); +} + +TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}), + std::make_pair(11, 23)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}), + std::make_pair(10, 11)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}), + std::make_pair(23, 24)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}), + std::make_pair(23, 24)); + EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today", + {14, 17}), + std::make_pair(11, 25)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}), + std::make_pair(11, 23)); + EXPECT_EQ( + classifier->SuggestSelection( + "let's meet at 350 Third Street Cambridge and go there", {30, 31}), + std::make_pair(14, 40)); + EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}), + std::make_pair(4, 5)); + EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}), + std::make_pair(7, 8)); + + // With a punctuation around the 
selected whitespace. + EXPECT_EQ( + classifier->SuggestSelection( + "let's meet at 350 Third Street, Cambridge and go there", {31, 32}), + std::make_pair(14, 41)); + + // When all's whitespace, should return the original indices. + EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}), + std::make_pair(0, 1)); + EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}), + std::make_pair(0, 3)); + EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}), + std::make_pair(2, 3)); + EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}), + std::make_pair(5, 6)); +} + +TEST(TextClassifierTest, SnapLeftIfWhitespaceSelection) { + CREATE_UNILIB_FOR_TESTING; + UnicodeText text; + + text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + std::make_pair(3, 4)); + text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + std::make_pair(3, 4)); + + // Nothing on the left. + text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + std::make_pair(4, 5)); + text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib), + std::make_pair(0, 1)); + + // Whitespace only. 
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib), + std::make_pair(2, 3)); + text = UTF8ToUnicodeText(" ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + std::make_pair(4, 5)); + text = UTF8ToUnicodeText(" ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib), + std::make_pair(0, 1)); +} + +TEST_P(TextClassifierTest, Annotate) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU + IsAnnotatedSpan(19, 24, "date"), +#endif + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + AnnotationOptions options; + EXPECT_THAT(classifier->Annotate("853 225 3556", options), + ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); + EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty()); + + // Try passing invalid utf8. + EXPECT_TRUE( + classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options) + .empty()); +} + +TEST_P(TextClassifierTest, AnnotateSmallBatches) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Set the batch size. 
+ unpacked_model->selection_options->batch_size = 4; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU + IsAnnotatedSpan(19, 24, "date"), +#endif + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + AnnotationOptions options; + EXPECT_THAT(classifier->Annotate("853 225 3556", options), + ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); + EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty()); +} + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + // Add test threshold. + unpacked_model->triggering_options->min_annotate_confidence = + 2.f; // Discards all results. + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + + EXPECT_EQ(classifier->Annotate(test_string).size(), 1); +} +#endif + +TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test thresholds. + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->min_annotate_confidence = + 0.f; // Keeps all results. + unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU + EXPECT_EQ(classifier->Annotate(test_string).size(), 3); +#else + // In non-ICU mode there is no "date" result. + EXPECT_EQ(classifier->Annotate(test_string).size(), 2); +#endif +} + +TEST_P(TextClassifierTest, AnnotateDisabled) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Disable the model for annotation. 
+ unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_THAT(classifier->Annotate(test_string), IsEmpty()); +} + +TEST_P(TextClassifierTest, AnnotateFilteredCollections) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(), + &unilib); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU + IsAnnotatedSpan(19, 24, "date"), +#endif + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // Disable phone annotation + unpacked_model->output_options->filtered_collections_annotation.push_back( + "phone"); + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU + IsAnnotatedSpan(19, 24, "date"), +#endif + IsAnnotatedSpan(28, 55, 
"address"), + })); +} + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(), + &unilib); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU + IsAnnotatedSpan(19, 24, "date"), +#endif + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // We add a custom annotator that wins against the phone classification + // below and that we subsequently suppress. 
+ unpacked_model->output_options->filtered_collections_annotation.push_back( + "suppress"); + + unpacked_model->regex_model->patterns.push_back(MakePattern( + "suppress", "(\\d{3} ?\\d{4})", + /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0)); + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(19, 24, "date"), + IsAnnotatedSpan(28, 55, "address"), + })); +} +#endif + +#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU +TEST_P(TextClassifierTest, ClassifyTextDate) { + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam()); + EXPECT_TRUE(classifier); + + std::vector<ClassificationResult> result; + ClassificationOptions options; + + options.reference_timezone = "Europe/Zurich"; + result = classifier->ClassifyText("january 1, 2017", {0, 15}, options); + + ASSERT_EQ(result.size(), 1); + EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_DAY); + result.clear(); + + options.reference_timezone = "America/Los_Angeles"; + result = classifier->ClassifyText("march 1, 2017", {0, 13}, options); + ASSERT_EQ(result.size(), 1); + EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_DAY); + result.clear(); + + options.reference_timezone = "America/Los_Angeles"; + result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options); + ASSERT_EQ(result.size(), 1); + 
EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_SECOND); + result.clear(); + + // Date on another line. + options.reference_timezone = "Europe/Zurich"; + result = classifier->ClassifyText( + "hello world this is the first line\n" + "january 1, 2017", + {35, 50}, options); + ASSERT_EQ(result.size(), 1); + EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_DAY); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU +TEST_P(TextClassifierTest, ClassifyTextDatePriorities) { + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam()); + EXPECT_TRUE(classifier); + + std::vector<ClassificationResult> result; + ClassificationOptions options; + + result.clear(); + options.reference_timezone = "Europe/Zurich"; + options.locales = "en-US"; + result = classifier->ClassifyText("03.05.1970", {0, 10}, options); + + ASSERT_EQ(result.size(), 1); + EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_DAY); + + result.clear(); + options.reference_timezone = "Europe/Zurich"; + options.locales = "de"; + result = classifier->ClassifyText("03.05.1970", {0, 10}, options); + + ASSERT_EQ(result.size(), 1); + EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_DAY); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU +TEST_P(TextClassifierTest, SuggestTextDateDisabled) { + 
CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Disable the patterns for selection. + for (int i = 0; i < unpacked_model->datetime_model->patterns.size(); i++) { + unpacked_model->datetime_model->patterns[i]->enabled_modes = + ModeFlag_ANNOTATION_AND_CLASSIFICATION; + } + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + EXPECT_EQ("date", + FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15}))); + EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}), + std::make_pair(0, 7)); + EXPECT_THAT(classifier->Annotate("january 1, 2017"), + ElementsAreArray({IsAnnotatedSpan(0, 15, "date")})); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +class TestingTextClassifier : public TextClassifier { + public: + TestingTextClassifier(const std::string& model, const UniLib* unilib) + : TextClassifier(ViewModel(model.data(), model.size()), unilib) {} + + using TextClassifier::ResolveConflicts; +}; + +AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span, + const std::string& collection, + const float score) { + AnnotatedSpan result; + result.span = span; + result.classification.push_back({collection, score}); + return result; +} + +TEST(TextClassifierTest, ResolveConflictsTrivial) { + CREATE_UNILIB_FOR_TESTING; + TestingTextClassifier classifier("", &unilib); + + std::vector<AnnotatedSpan> candidates{ + {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}}; + + std::vector<int> chosen; + classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, + /*interpreter_manager=*/nullptr, &chosen); + EXPECT_THAT(chosen, ElementsAreArray({0})); +} + +TEST(TextClassifierTest, 
ResolveConflictsSequence) { + CREATE_UNILIB_FOR_TESTING; + TestingTextClassifier classifier("", &unilib); + + std::vector<AnnotatedSpan> candidates{{ + MakeAnnotatedSpan({0, 1}, "phone", 1.0), + MakeAnnotatedSpan({1, 2}, "phone", 1.0), + MakeAnnotatedSpan({2, 3}, "phone", 1.0), + MakeAnnotatedSpan({3, 4}, "phone", 1.0), + MakeAnnotatedSpan({4, 5}, "phone", 1.0), + }}; + + std::vector<int> chosen; + classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, + /*interpreter_manager=*/nullptr, &chosen); + EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4})); +} + +TEST(TextClassifierTest, ResolveConflictsThreeSpans) { + CREATE_UNILIB_FOR_TESTING; + TestingTextClassifier classifier("", &unilib); + + std::vector<AnnotatedSpan> candidates{{ + MakeAnnotatedSpan({0, 3}, "phone", 1.0), + MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser! + MakeAnnotatedSpan({3, 7}, "phone", 1.0), + }}; + + std::vector<int> chosen; + classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, + /*interpreter_manager=*/nullptr, &chosen); + EXPECT_THAT(chosen, ElementsAreArray({0, 2})); +} + +TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) { + CREATE_UNILIB_FOR_TESTING; + TestingTextClassifier classifier("", &unilib); + + std::vector<AnnotatedSpan> candidates{{ + MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser! + MakeAnnotatedSpan({1, 5}, "phone", 1.0), + MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser! + }}; + + std::vector<int> chosen; + classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, + /*interpreter_manager=*/nullptr, &chosen); + EXPECT_THAT(chosen, ElementsAreArray({1})); +} + +TEST(TextClassifierTest, ResolveConflictsFiveSpans) { + CREATE_UNILIB_FOR_TESTING; + TestingTextClassifier classifier("", &unilib); + + std::vector<AnnotatedSpan> candidates{{ + MakeAnnotatedSpan({0, 3}, "phone", 0.5), + MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser! 
+ MakeAnnotatedSpan({3, 7}, "phone", 0.6), + MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser! + MakeAnnotatedSpan({11, 15}, "phone", 0.9), + }}; + + std::vector<int> chosen; + classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, + /*interpreter_manager=*/nullptr, &chosen); + EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4})); +} + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, LongInput) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + for (const auto& type_value_pair : + std::vector<std::pair<std::string, std::string>>{ + {"address", "350 Third Street, Cambridge"}, + {"phone", "123 456-7890"}, + {"url", "www.google.com"}, + {"email", "someone@gmail.com"}, + {"flight", "LX 38"}, + {"date", "September 1, 2018"}}) { + const std::string input_100k = std::string(50000, ' ') + + type_value_pair.second + + std::string(50000, ' '); + const int value_length = type_value_pair.second.size(); + + EXPECT_THAT(classifier->Annotate(input_100k), + ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length, + type_value_pair.first)})); + EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001}), + std::make_pair(50000, 50000 + value_length)); + EXPECT_EQ(type_value_pair.first, + FirstResult(classifier->ClassifyText( + input_100k, {50000, 50000 + value_length}))); + } +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +// These coarse tests are there only to make sure the execution happens in +// reasonable amount of time. 
+TEST_P(TextClassifierTest, LongInputNoResultCheck) { + CREATE_UNILIB_FOR_TESTING; + std::unique_ptr<TextClassifier> classifier = + TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); + ASSERT_TRUE(classifier); + + for (const std::string& value : + std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) { + const std::string input_100k = + std::string(50000, ' ') + value + std::string(50000, ' '); + const int value_length = value.size(); + + classifier->Annotate(input_100k); + classifier->SuggestSelection(input_100k, {50000, 50001}); + classifier->ClassifyText(input_100k, {50000, 50000 + value_length}); + } +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, MaxTokenLength) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + std::unique_ptr<TextClassifier> classifier; + + // With unrestricted number of tokens should behave normally. + unpacked_model->classification_options->max_num_tokens = -1; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ(FirstResult(classifier->ClassifyText( + "I live at 350 Third Street, Cambridge.", {10, 37})), + "address"); + + // Raise the maximum number of tokens to suppress the classification. 
+ unpacked_model->classification_options->max_num_tokens = 3; + + flatbuffers::FlatBufferBuilder builder2; + builder2.Finish(Model::Pack(builder2, unpacked_model.get())); + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder2.GetBufferPointer()), + builder2.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ(FirstResult(classifier->ClassifyText( + "I live at 350 Third Street, Cambridge.", {10, 37})), + "other"); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST_P(TextClassifierTest, MinAddressTokenLength) { + CREATE_UNILIB_FOR_TESTING; + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + std::unique_ptr<TextClassifier> classifier; + + // With unrestricted number of address tokens should behave normally. + unpacked_model->classification_options->address_min_num_tokens = 0; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, unpacked_model.get())); + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ(FirstResult(classifier->ClassifyText( + "I live at 350 Third Street, Cambridge.", {10, 37})), + "address"); + + // Raise number of address tokens to suppress the address classification. 
+ unpacked_model->classification_options->address_min_num_tokens = 5; + + flatbuffers::FlatBufferBuilder builder2; + builder2.Finish(Model::Pack(builder2, unpacked_model.get())); + classifier = TextClassifier::FromUnownedBuffer( + reinterpret_cast<const char*>(builder2.GetBufferPointer()), + builder2.GetSize(), &unilib); + ASSERT_TRUE(classifier); + + EXPECT_EQ(FirstResult(classifier->ClassifyText( + "I live at 350 Third Street, Cambridge.", {10, 37})), + "other"); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +} // namespace +} // namespace libtextclassifier2 diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc index 8740f4c..29cf745 100644 --- a/textclassifier_jni.cc +++ b/textclassifier_jni.cc @@ -14,56 +14,40 @@ * limitations under the License. */ -// Simple JNI wrapper for the SmartSelection library. +// JNI wrapper for the TextClassifier. #include "textclassifier_jni.h" #include <jni.h> +#include <type_traits> #include <vector> -#include "lang_id/lang-id.h" -#include "smartselect/text-classification-model.h" +#include "text-classifier.h" +#include "util/base/integral_types.h" #include "util/java/scoped_local_ref.h" +#include "util/java/string_utils.h" +#include "util/memory/mmap.h" +#include "util/utf8/unilib.h" + +using libtextclassifier2::AnnotatedSpan; +using libtextclassifier2::AnnotationOptions; +using libtextclassifier2::ClassificationOptions; +using libtextclassifier2::ClassificationResult; +using libtextclassifier2::CodepointSpan; +using libtextclassifier2::JStringToUtf8String; +using libtextclassifier2::Model; +using libtextclassifier2::ScopedLocalRef; +using libtextclassifier2::SelectionOptions; +using libtextclassifier2::TextClassifier; +#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU +using libtextclassifier2::UniLib; +#endif -using libtextclassifier::ModelOptions; -using libtextclassifier::TextClassificationModel; -using libtextclassifier::nlp_core::lang_id::LangId; - -namespace { - -bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, - 
std::string* result) { - if (jstr == nullptr) { - *result = std::string(); - return false; - } - - jclass string_class = env->FindClass("java/lang/String"); - if (!string_class) { - TC_LOG(ERROR) << "Can't find String class"; - return false; - } - - jmethodID get_bytes_id = - env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); - - jstring encoding = env->NewStringUTF("UTF-8"); - jbyteArray array = reinterpret_cast<jbyteArray>( - env->CallObjectMethod(jstr, get_bytes_id, encoding)); - - jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE); - int length = env->GetArrayLength(array); - - *result = std::string(reinterpret_cast<char*>(array_bytes), length); +namespace libtextclassifier2 { - // Release the array. - env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT); - env->DeleteLocalRef(array); - env->DeleteLocalRef(string_class); - env->DeleteLocalRef(encoding); +using libtextclassifier2::CodepointSpan; - return true; -} +namespace { std::string ToStlString(JNIEnv* env, const jstring& str) { std::string result; @@ -71,47 +55,143 @@ std::string ToStlString(JNIEnv* env, const jstring& str) { return result; } -jobjectArray ScoredStringsToJObjectArray( - JNIEnv* env, const std::string& result_class_name, - const std::vector<std::pair<std::string, float>>& classification_result) { - jclass result_class = env->FindClass(result_class_name.c_str()); +jobjectArray ClassificationResultsToJObjectArray( + JNIEnv* env, + const std::vector<ClassificationResult>& classification_result) { + const ScopedLocalRef<jclass> result_class( + env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"), + env); if (!result_class) { - TC_LOG(ERROR) << "Couldn't find result class: " << result_class_name; + TC_LOG(ERROR) << "Couldn't find ClassificationResult class."; + return nullptr; + } + const ScopedLocalRef<jclass> datetime_parse_class( + env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env); + if 
(!datetime_parse_class) { + TC_LOG(ERROR) << "Couldn't find DatetimeResult class."; return nullptr; } - jmethodID result_class_constructor = - env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V"); - - jobjectArray results = - env->NewObjectArray(classification_result.size(), result_class, nullptr); + const jmethodID result_class_constructor = + env->GetMethodID(result_class.get(), "<init>", + "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR + "$DatetimeResult;)V"); + const jmethodID datetime_parse_class_constructor = + env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); + const jobjectArray results = env->NewObjectArray(classification_result.size(), + result_class.get(), nullptr); for (int i = 0; i < classification_result.size(); i++) { jstring row_string = - env->NewStringUTF(classification_result[i].first.c_str()); + env->NewStringUTF(classification_result[i].collection.c_str()); + jobject row_datetime_parse = nullptr; + if (classification_result[i].datetime_parse_result.IsSet()) { + row_datetime_parse = env->NewObject( + datetime_parse_class.get(), datetime_parse_class_constructor, + classification_result[i].datetime_parse_result.time_ms_utc, + classification_result[i].datetime_parse_result.granularity); + } jobject result = - env->NewObject(result_class, result_class_constructor, row_string, - static_cast<jfloat>(classification_result[i].second)); + env->NewObject(result_class.get(), result_class_constructor, row_string, + static_cast<jfloat>(classification_result[i].score), + row_datetime_parse); env->SetObjectArrayElement(results, i, result); env->DeleteLocalRef(result); } - env->DeleteLocalRef(result_class); return results; } -} // namespace +template <typename T, typename F> +std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object, + jclass class_object, F function, + const std::string& method_name, + const std::string& return_java_type) { + const jmethodID method = env->GetMethodID(class_object, 
method_name.c_str(), + ("()" + return_java_type).c_str()); + if (!method) { + return std::make_pair(false, T()); + } + return std::make_pair(true, (env->*function)(object, method)); +} -namespace libtextclassifier { +SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { + if (!joptions) { + return {}; + } -using libtextclassifier::CodepointSpan; + const ScopedLocalRef<jclass> options_class( + env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"), + env); + const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( + env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, + "getLocales", "Ljava/lang/String;"); + if (!status_or_locales.first) { + return {}; + } -namespace { + SelectionOptions options; + options.locales = + ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); + + return options; +} + +template <typename T> +T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, + const std::string& class_name) { + if (!joptions) { + return {}; + } + + const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), + env); + if (!options_class) { + return {}; + } + + const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( + env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, + "getLocale", "Ljava/lang/String;"); + const std::pair<bool, jobject> status_or_reference_timezone = + CallJniMethod0<jobject>(env, joptions, options_class.get(), + &JNIEnv::CallObjectMethod, "getReferenceTimezone", + "Ljava/lang/String;"); + const std::pair<bool, int64> status_or_reference_time_ms_utc = + CallJniMethod0<int64>(env, joptions, options_class.get(), + &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", + "J"); + + if (!status_or_locales.first || !status_or_reference_timezone.first || + !status_or_reference_time_ms_utc.first) { + return {}; + } + + T options; + options.locales = + ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); + 
options.reference_timezone = ToStlString( + env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); + options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; + return options; +} + +ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, + jobject joptions) { + return FromJavaOptionsInternal<ClassificationOptions>( + env, joptions, + TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions"); +} + +AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { + return FromJavaOptionsInternal<AnnotationOptions>( + env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions"); +} CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, CodepointSpan orig_indices, bool from_utf8) { - const libtextclassifier::UnicodeText unicode_str = - libtextclassifier::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); + const libtextclassifier2::UnicodeText unicode_str = + libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); int unicode_index = 0; int bmp_index = 0; @@ -155,83 +235,142 @@ CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, } // namespace CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str, - CodepointSpan orig_indices) { - return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/false); + CodepointSpan bmp_indices) { + return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false); } CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, - CodepointSpan orig_indices) { - return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/true); + CodepointSpan utf8_indices) { + return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true); } -} // namespace libtextclassifier - -using libtextclassifier::CodepointSpan; -using libtextclassifier::ConvertIndicesBMPToUTF8; -using libtextclassifier::ConvertIndicesUTF8ToBMP; -using libtextclassifier::ScopedLocalRef; - -JNI_METHOD(jlong, SmartSelection, nativeNew) -(JNIEnv* 
env, jobject thiz, jint fd) { - TextClassificationModel* model = new TextClassificationModel(fd); - return reinterpret_cast<jlong>(model); -} - -JNI_METHOD(jlong, SmartSelection, nativeNewFromPath) -(JNIEnv* env, jobject thiz, jstring path) { - const std::string path_str = ToStlString(env, path); - TextClassificationModel* model = new TextClassificationModel(path_str); - return reinterpret_cast<jlong>(model); -} - -JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { +jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) { // Get system-level file descriptor from AssetFileDescriptor. ScopedLocalRef<jclass> afd_class( env->FindClass("android/content/res/AssetFileDescriptor"), env); if (afd_class == nullptr) { - TC_LOG(ERROR) << "Couln't find AssetFileDescriptor."; + TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor."; return reinterpret_cast<jlong>(nullptr); } jmethodID afd_class_getFileDescriptor = env->GetMethodID( afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;"); if (afd_class_getFileDescriptor == nullptr) { - TC_LOG(ERROR) << "Couln't find getFileDescriptor."; + TC_LOG(ERROR) << "Couldn't find getFileDescriptor."; return reinterpret_cast<jlong>(nullptr); } ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"), env); if (fd_class == nullptr) { - TC_LOG(ERROR) << "Couln't find FileDescriptor."; + TC_LOG(ERROR) << "Couldn't find FileDescriptor."; return reinterpret_cast<jlong>(nullptr); } jfieldID fd_class_descriptor = env->GetFieldID(fd_class.get(), "descriptor", "I"); if (fd_class_descriptor == nullptr) { - TC_LOG(ERROR) << "Couln't find descriptor."; + TC_LOG(ERROR) << "Couldn't find descriptor."; return reinterpret_cast<jlong>(nullptr); } jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor); - jint bundle_cfd = env->GetIntField(bundle_jfd, fd_class_descriptor); + return env->GetIntField(bundle_jfd, 
fd_class_descriptor); +} - TextClassificationModel* model = - new TextClassificationModel(bundle_cfd, offset, size); - return reinterpret_cast<jlong>(model); +jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { + if (!mmap->handle().ok()) { + return env->NewStringUTF(""); + } + const Model* model = libtextclassifier2::ViewModel( + mmap->handle().start(), mmap->handle().num_bytes()); + if (!model || !model->locales()) { + return env->NewStringUTF(""); + } + return env->NewStringUTF(model->locales()->c_str()); } -JNI_METHOD(jintArray, SmartSelection, nativeSuggest) +jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { + if (!mmap->handle().ok()) { + return 0; + } + const Model* model = libtextclassifier2::ViewModel( + mmap->handle().start(), mmap->handle().num_bytes()); + if (!model) { + return 0; + } + return model->version(); +} + +jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { + if (!mmap->handle().ok()) { + return env->NewStringUTF(""); + } + const Model* model = libtextclassifier2::ViewModel( + mmap->handle().start(), mmap->handle().num_bytes()); + if (!model || !model->name()) { + return env->NewStringUTF(""); + } + return env->NewStringUTF(model->name()->c_str()); +} + +} // namespace libtextclassifier2 + +using libtextclassifier2::ClassificationResultsToJObjectArray; +using libtextclassifier2::ConvertIndicesBMPToUTF8; +using libtextclassifier2::ConvertIndicesUTF8ToBMP; +using libtextclassifier2::FromJavaAnnotationOptions; +using libtextclassifier2::FromJavaClassificationOptions; +using libtextclassifier2::FromJavaSelectionOptions; +using libtextclassifier2::ToStlString; + +JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew) +(JNIEnv* env, jobject thiz, jint fd) { +#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU + return reinterpret_cast<jlong>( + TextClassifier::FromFileDescriptor(fd).release(), new UniLib(env)); +#else + return reinterpret_cast<jlong>( + 
TextClassifier::FromFileDescriptor(fd).release()); +#endif +} + +JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath) +(JNIEnv* env, jobject thiz, jstring path) { + const std::string path_str = ToStlString(env, path); +#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU + return reinterpret_cast<jlong>( + TextClassifier::FromPath(path_str, new UniLib(env)).release()); +#else + return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release()); +#endif +} + +JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); +#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU + return reinterpret_cast<jlong>( + TextClassifier::FromFileDescriptor(fd, offset, size, new UniLib(env)) + .release()); +#else + return reinterpret_cast<jlong>( + TextClassifier::FromFileDescriptor(fd, offset, size).release()); +#endif +} + +JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection) (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end) { - TextClassificationModel* model = - reinterpret_cast<TextClassificationModel*>(ptr); + jint selection_end, jobject options) { + if (!ptr) { + return nullptr; + } + + TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); const std::string context_utf8 = ToStlString(env, context); CodepointSpan input_indices = ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); - CodepointSpan selection = - model->SuggestSelection(context_utf8, input_indices); + CodepointSpan selection = model->SuggestSelection( + context_utf8, input_indices, FromJavaSelectionOptions(env, options)); selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); jintArray result = env->NewIntArray(2); @@ -240,39 +379,45 @@ JNI_METHOD(jintArray, SmartSelection, nativeSuggest) return result; } -JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText) 
+JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText) (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end, jint input_flags) { - TextClassificationModel* ff_model = - reinterpret_cast<TextClassificationModel*>(ptr); - const std::vector<std::pair<std::string, float>> classification_result = - ff_model->ClassifyText(ToStlString(env, context), - {selection_begin, selection_end}, input_flags); - - return ScoredStringsToJObjectArray( - env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult", - classification_result); + jint selection_end, jobject options) { + if (!ptr) { + return nullptr; + } + TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr); + + const std::string context_utf8 = ToStlString(env, context); + const CodepointSpan input_indices = + ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); + const std::vector<ClassificationResult> classification_result = + ff_model->ClassifyText(context_utf8, input_indices, + FromJavaClassificationOptions(env, options)); + + return ClassificationResultsToJObjectArray(env, classification_result); } -JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context) { - TextClassificationModel* model = - reinterpret_cast<TextClassificationModel*>(ptr); +JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) { + if (!ptr) { + return nullptr; + } + TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); std::string context_utf8 = ToStlString(env, context); - std::vector<TextClassificationModel::AnnotatedSpan> annotations = - model->Annotate(context_utf8); + std::vector<AnnotatedSpan> annotations = + model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options)); jclass result_class = - env->FindClass(TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan"); + env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR 
"$AnnotatedSpan"); if (!result_class) { TC_LOG(ERROR) << "Couldn't find result class: " - << TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan"; + << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"; return nullptr; } jmethodID result_class_constructor = env->GetMethodID( result_class, "<init>", - "(II[L" TC_PACKAGE_PATH "SmartSelection$ClassificationResult;)V"); + "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V"); jobjectArray results = env->NewObjectArray(annotations.size(), result_class, nullptr); @@ -283,9 +428,9 @@ JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate) jobject result = env->NewObject( result_class, result_class_constructor, static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second), - ScoredStringsToJObjectArray( - env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult", - annotations[i].classification)); + ClassificationResultsToJObjectArray(env, + + annotations[i].classification)); env->SetObjectArrayElement(results, i, result); env->DeleteLocalRef(result); } @@ -293,58 +438,59 @@ JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate) return results; } -JNI_METHOD(void, SmartSelection, nativeClose) +JNI_METHOD(void, TC_CLASS_NAME, nativeClose) (JNIEnv* env, jobject thiz, jlong ptr) { - TextClassificationModel* model = - reinterpret_cast<TextClassificationModel*>(ptr); + TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); delete model; } -JNI_METHOD(jstring, SmartSelection, nativeGetLanguage) +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage) (JNIEnv* env, jobject clazz, jint fd) { - ModelOptions model_options; - if (ReadSelectionModelOptions(fd, &model_options)) { - return env->NewStringUTF(model_options.language().c_str()); - } else { - return env->NewStringUTF("UNK"); - } + TC_LOG(WARNING) << "Using deprecated getLanguage()."; + return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd); } -JNI_METHOD(jint, SmartSelection, nativeGetVersion) +JNI_METHOD(jstring, TC_CLASS_NAME, 
nativeGetLocales) (JNIEnv* env, jobject clazz, jint fd) { - ModelOptions model_options; - if (ReadSelectionModelOptions(fd, &model_options)) { - return model_options.version(); - } else { - return -1; - } + const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( + new libtextclassifier2::ScopedMmap(fd)); + return GetLocalesFromMmap(env, mmap.get()); } -#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID -JNI_METHOD(jlong, LangId, nativeNew) -(JNIEnv* env, jobject thiz, jint fd) { - return reinterpret_cast<jlong>(new LangId(fd)); +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( + new libtextclassifier2::ScopedMmap(fd, offset, size)); + return GetLocalesFromMmap(env, mmap.get()); } -JNI_METHOD(jobjectArray, LangId, nativeFindLanguages) -(JNIEnv* env, jobject thiz, jlong ptr, jstring text) { - LangId* lang_id = reinterpret_cast<LangId*>(ptr); - const std::vector<std::pair<std::string, float>> scored_languages = - lang_id->FindLanguages(ToStlString(env, text)); - - return ScoredStringsToJObjectArray( - env, TC_PACKAGE_PATH "LangId$ClassificationResult", scored_languages); +JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion) +(JNIEnv* env, jobject clazz, jint fd) { + const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( + new libtextclassifier2::ScopedMmap(fd)); + return GetVersionFromMmap(env, mmap.get()); } -JNI_METHOD(void, LangId, nativeClose) -(JNIEnv* env, jobject thiz, jlong ptr) { - LangId* lang_id = reinterpret_cast<LangId*>(ptr); - delete lang_id; +JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( + 
new libtextclassifier2::ScopedMmap(fd, offset, size)); + return GetVersionFromMmap(env, mmap.get()); } -JNI_METHOD(int, LangId, nativeGetVersion) +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName) (JNIEnv* env, jobject clazz, jint fd) { - std::unique_ptr<LangId> lang_id(new LangId(fd)); - return lang_id->version(); + const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( + new libtextclassifier2::ScopedMmap(fd)); + return GetNameFromMmap(env, mmap.get()); +} + +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( + new libtextclassifier2::ScopedMmap(fd, offset, size)); + return GetNameFromMmap(env, mmap.get()); } -#endif diff --git a/textclassifier_jni.h b/textclassifier_jni.h index 1709ff4..d6e742e 100644 --- a/textclassifier_jni.h +++ b/textclassifier_jni.h @@ -20,89 +20,115 @@ #include <jni.h> #include <string> -#include "smartselect/types.h" +#include "types.h" + +// When we use a macro as an argument for a macro, an additional level of +// indirection is needed, if the macro argument is used with # or ##. 
+#define ADD_QUOTES_HELPER(TOKEN) #TOKEN +#define ADD_QUOTES(TOKEN) ADD_QUOTES_HELPER(TOKEN) #ifndef TC_PACKAGE_NAME #define TC_PACKAGE_NAME android_view_textclassifier #endif + +#ifndef TC_CLASS_NAME +#define TC_CLASS_NAME TextClassifierImplNative +#endif +#define TC_CLASS_NAME_STR ADD_QUOTES(TC_CLASS_NAME) + #ifndef TC_PACKAGE_PATH #define TC_PACKAGE_PATH "android/view/textclassifier/" #endif +#define JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) \ + Java_##package_name##_##class_name##_##method_name + #define JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \ method_name) \ - JNIEXPORT return_type JNICALL \ - Java_##package_name##_##class_name##_##method_name + JNIEXPORT return_type JNICALL JNI_METHOD_NAME_INTERNAL( \ + package_name, class_name, method_name) // The indirection is needed to correctly expand the TC_PACKAGE_NAME macro. +// See the explanation near ADD_QUOTES macro. #define JNI_METHOD2(return_type, package_name, class_name, method_name) \ JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name) #define JNI_METHOD(return_type, class_name, method_name) \ JNI_METHOD2(return_type, TC_PACKAGE_NAME, class_name, method_name) +#define JNI_METHOD_NAME2(package_name, class_name, method_name) \ + JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) + +#define JNI_METHOD_NAME(class_name, method_name) \ + JNI_METHOD_NAME2(TC_PACKAGE_NAME, class_name, method_name) + #ifdef __cplusplus extern "C" { #endif // SmartSelection. 
-JNI_METHOD(jlong, SmartSelection, nativeNew) +JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew) (JNIEnv* env, jobject thiz, jint fd); -JNI_METHOD(jlong, SmartSelection, nativeNewFromPath) +JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath) (JNIEnv* env, jobject thiz, jstring path); -JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor) +JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor) (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); -JNI_METHOD(jintArray, SmartSelection, nativeSuggest) +JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection) (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end); + jint selection_end, jobject options); -JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText) +JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText) (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end, jint input_flags); + jint selection_end, jobject options); -JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context); +JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options); -JNI_METHOD(void, SmartSelection, nativeClose) +JNI_METHOD(void, TC_CLASS_NAME, nativeClose) (JNIEnv* env, jobject thiz, jlong ptr); -JNI_METHOD(jstring, SmartSelection, nativeGetLanguage) +// DEPRECATED. Use nativeGetLocales instead. +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage) (JNIEnv* env, jobject clazz, jint fd); -JNI_METHOD(jint, SmartSelection, nativeGetVersion) +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales) (JNIEnv* env, jobject clazz, jint fd); -#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID -// LangId. 
-JNI_METHOD(jlong, LangId, nativeNew)(JNIEnv* env, jobject thiz, jint fd); +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); -JNI_METHOD(jobjectArray, LangId, nativeFindLanguages) -(JNIEnv* env, jobject thiz, jlong ptr, jstring text); +JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion) +(JNIEnv* env, jobject clazz, jint fd); -JNI_METHOD(void, LangId, nativeClose)(JNIEnv* env, jobject thiz, jlong ptr); +JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); -JNI_METHOD(int, LangId, nativeGetVersion)(JNIEnv* env, jobject clazz, jint fd); -#endif +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName) +(JNIEnv* env, jobject clazz, jint fd); + +JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); #ifdef __cplusplus } #endif -namespace libtextclassifier { +namespace libtextclassifier2 { // Given a utf8 string and a span expressed in Java BMP (basic multilingual // plane) codepoints, converts it to a span expressed in utf8 codepoints. -libtextclassifier::CodepointSpan ConvertIndicesBMPToUTF8( - const std::string& utf8_str, libtextclassifier::CodepointSpan bmp_indices); +libtextclassifier2::CodepointSpan ConvertIndicesBMPToUTF8( + const std::string& utf8_str, libtextclassifier2::CodepointSpan bmp_indices); // Given a utf8 string and a span expressed in utf8 codepoints, converts it to a // span expressed in Java BMP (basic multilingual plane) codepoints. 
-libtextclassifier::CodepointSpan ConvertIndicesUTF8ToBMP( - const std::string& utf8_str, libtextclassifier::CodepointSpan utf8_indices); +libtextclassifier2::CodepointSpan ConvertIndicesUTF8ToBMP( + const std::string& utf8_str, + libtextclassifier2::CodepointSpan utf8_indices); -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_ diff --git a/textclassifier_jni_test.cc b/textclassifier_jni_test.cc index ffc193b..87b96fa 100644 --- a/textclassifier_jni_test.cc +++ b/textclassifier_jni_test.cc @@ -19,7 +19,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace { TEST(TextClassifier, ConvertIndicesBMPUTF8) { @@ -76,4 +76,4 @@ TEST(TextClassifier, ConvertIndicesBMPUTF8) { } } // namespace -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/smartselect/token-feature-extractor.cc b/token-feature-extractor.cc index 6afd951..13fba30 100644 --- a/smartselect/token-feature-extractor.cc +++ b/token-feature-extractor.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "smartselect/token-feature-extractor.h" +#include "token-feature-extractor.h" #include <cctype> #include <string> @@ -23,12 +23,8 @@ #include "util/hash/farmhash.h" #include "util/strings/stringpiece.h" #include "util/utf8/unicodetext.h" -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT -#include "unicode/regex.h" -#include "unicode/uchar.h" -#endif -namespace libtextclassifier { +namespace libtextclassifier2 { namespace { @@ -50,69 +46,121 @@ std::string RemapTokenAscii(const std::string& token, return copy; } -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT void RemapTokenUnicode(const std::string& token, const TokenFeatureExtractorOptions& options, - UnicodeText* remapped) { + const UniLib& unilib, UnicodeText* remapped) { if (!options.remap_digits && !options.lowercase_tokens) { // Leave remapped untouched. 
return; } UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false); - icu::UnicodeString icu_string; + remapped->clear(); for (auto it = word.begin(); it != word.end(); ++it) { - if (options.remap_digits && u_isdigit(*it)) { - icu_string.append('0'); + if (options.remap_digits && unilib.IsDigit(*it)) { + remapped->AppendCodepoint('0'); } else if (options.lowercase_tokens) { - icu_string.append(u_tolower(*it)); + remapped->AppendCodepoint(unilib.ToLower(*it)); } else { - icu_string.append(*it); + remapped->AppendCodepoint(*it); } } - std::string utf8_str; - icu_string.toUTF8String(utf8_str); - remapped->CopyUTF8(utf8_str.data(), utf8_str.length()); } -#endif } // namespace TokenFeatureExtractor::TokenFeatureExtractor( - const TokenFeatureExtractorOptions& options) - : options_(options) { -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - UErrorCode status; + const TokenFeatureExtractorOptions& options, const UniLib& unilib) + : options_(options), unilib_(unilib) { for (const std::string& pattern : options.regexp_features) { - status = U_ZERO_ERROR; - regex_patterns_.push_back( - std::unique_ptr<icu::RegexPattern>(icu::RegexPattern::compile( - icu::UnicodeString(pattern.c_str(), pattern.size(), "utf-8"), 0, - status))); - if (U_FAILURE(status)) { - TC_LOG(WARNING) << "Failed to load pattern" << pattern; + regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>( + unilib_.CreateRegexPattern(UTF8ToUnicodeText( + pattern.c_str(), pattern.size(), /*do_copy=*/false)))); + } +} + +bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span, + std::vector<int>* sparse_features, + std::vector<float>* dense_features) const { + if (!dense_features) { + return false; + } + if (sparse_features) { + *sparse_features = ExtractCharactergramFeatures(token); + } + *dense_features = ExtractDenseFeatures(token, is_in_span); + return true; +} + +std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures( + const Token& token) const { + if 
(options_.unicode_aware_features) { + return ExtractCharactergramFeaturesUnicode(token); + } else { + return ExtractCharactergramFeaturesAscii(token); + } +} + +std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures( + const Token& token, bool is_in_span) const { + std::vector<float> dense_features; + + if (options_.extract_case_feature) { + if (options_.unicode_aware_features) { + UnicodeText token_unicode = + UTF8ToUnicodeText(token.value, /*do_copy=*/false); + const bool is_upper = unilib_.IsUpper(*token_unicode.begin()); + if (!token.value.empty() && is_upper) { + dense_features.push_back(1.0); + } else { + dense_features.push_back(-1.0); + } + } else { + if (!token.value.empty() && isupper(*token.value.begin())) { + dense_features.push_back(1.0); + } else { + dense_features.push_back(-1.0); + } } } -#else - bool found_unsupported_regexp_features = false; - for (const std::string& pattern : options.regexp_features) { - // A temporary solution to support this specific regexp pattern without - // adding too much binary size. - if (pattern == "^[^a-z]*$") { - enable_all_caps_feature_ = true; + + if (options_.extract_selection_mask_feature) { + if (is_in_span) { + dense_features.push_back(1.0); } else { - found_unsupported_regexp_features = true; + if (options_.unicode_aware_features) { + dense_features.push_back(-1.0); + } else { + dense_features.push_back(0.0); + } } } - if (found_unsupported_regexp_features) { - TC_LOG(WARNING) << "ICU not supported regexp features ignored."; + + // Add regexp features. 
+ if (!regex_patterns_.empty()) { + UnicodeText token_unicode = + UTF8ToUnicodeText(token.value, /*do_copy=*/false); + for (int i = 0; i < regex_patterns_.size(); ++i) { + if (!regex_patterns_[i].get()) { + dense_features.push_back(-1.0); + continue; + } + auto matcher = regex_patterns_[i]->Matcher(token_unicode); + int status; + if (matcher->Matches(&status)) { + dense_features.push_back(1.0); + } else { + dense_features.push_back(-1.0); + } + } } -#endif + + return dense_features; } int TokenFeatureExtractor::HashToken(StringPiece token) const { if (options_.allowed_chargrams.empty()) { - return tcfarmhash::Fingerprint64(token) % options_.num_buckets; + return tc2farmhash::Fingerprint64(token) % options_.num_buckets; } else { // Padding and out-of-vocabulary tokens have extra buckets reserved because // they are special and important tokens, and we don't want them to share @@ -126,22 +174,13 @@ int TokenFeatureExtractor::HashToken(StringPiece token) const { options_.allowed_chargrams.end()) { return 0; // Out-of-vocabulary. 
} else { - return (tcfarmhash::Fingerprint64(token) % + return (tc2farmhash::Fingerprint64(token) % (options_.num_buckets - kNumExtraBuckets)) + kNumExtraBuckets; } } } -std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures( - const Token& token) const { - if (options_.unicode_aware_features) { - return ExtractCharactergramFeaturesUnicode(token); - } else { - return ExtractCharactergramFeaturesAscii(token); - } -} - std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii( const Token& token) const { std::vector<int> result; @@ -192,13 +231,12 @@ std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii( std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode( const Token& token) const { -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT std::vector<int> result; if (token.is_padding || token.value.empty()) { result.push_back(HashToken("<PAD>")); } else { UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false); - RemapTokenUnicode(token.value, options_, &word); + RemapTokenUnicode(token.value, options_, unilib_, &word); // Trim the word if needed by finding a left-cut point and right-cut point. auto left_cut = word.begin(); @@ -268,98 +306,6 @@ std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode( } } return result; -#else - TC_LOG(WARNING) << "ICU not supported. 
No feature extracted."; - return {}; -#endif -} - -bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span, - std::vector<int>* sparse_features, - std::vector<float>* dense_features) const { - if (sparse_features == nullptr || dense_features == nullptr) { - return false; - } - - *sparse_features = ExtractCharactergramFeatures(token); - - if (options_.extract_case_feature) { - if (options_.unicode_aware_features) { - UnicodeText token_unicode = - UTF8ToUnicodeText(token.value, /*do_copy=*/false); - bool is_upper; -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - is_upper = u_isupper(*token_unicode.begin()); -#else - TC_LOG(WARNING) << "Using non-unicode isupper because ICU is disabled."; - is_upper = isupper(*token_unicode.begin()); -#endif - if (!token.value.empty() && is_upper) { - dense_features->push_back(1.0); - } else { - dense_features->push_back(-1.0); - } - } else { - if (!token.value.empty() && isupper(*token.value.begin())) { - dense_features->push_back(1.0); - } else { - dense_features->push_back(-1.0); - } - } - } - - if (options_.extract_selection_mask_feature) { - if (is_in_span) { - dense_features->push_back(1.0); - } else { - if (options_.unicode_aware_features) { - dense_features->push_back(-1.0); - } else { - dense_features->push_back(0.0); - } - } - } - -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - // Add regexp features. - if (!regex_patterns_.empty()) { - icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(), - "utf-8"); - for (int i = 0; i < regex_patterns_.size(); ++i) { - if (!regex_patterns_[i].get()) { - dense_features->push_back(-1.0); - continue; - } - - // Check for match. 
- UErrorCode status = U_ZERO_ERROR; - std::unique_ptr<icu::RegexMatcher> matcher( - regex_patterns_[i]->matcher(unicode_str, status)); - if (matcher->find()) { - dense_features->push_back(1.0); - } else { - dense_features->push_back(-1.0); - } - } - } -#else - if (enable_all_caps_feature_) { - bool is_all_caps = true; - for (const char character_byte : token.value) { - if (islower(character_byte)) { - is_all_caps = false; - break; - } - } - if (is_all_caps) { - dense_features->push_back(1.0); - } else { - dense_features->push_back(-1.0); - } - } -#endif - - return true; } -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/smartselect/token-feature-extractor.h b/token-feature-extractor.h index 5afeca4..fee1355 100644 --- a/smartselect/token-feature-extractor.h +++ b/token-feature-extractor.h @@ -14,20 +14,18 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_ +#ifndef LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_ +#define LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_ #include <memory> #include <unordered_set> #include <vector> -#include "smartselect/types.h" +#include "types.h" #include "util/strings/stringpiece.h" -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT -#include "unicode/regex.h" -#endif +#include "util/utf8/unilib.h" -namespace libtextclassifier { +namespace libtextclassifier2 { struct TokenFeatureExtractorOptions { // Number of buckets used for hashing charactergrams. @@ -67,29 +65,30 @@ struct TokenFeatureExtractorOptions { class TokenFeatureExtractor { public: - explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options); - - // Extracts features from a token. - // - is_in_span is a bool indicator whether the token is a part of the - // selection span (true) or not (false). 
- // - sparse_features are indices into a sparse feature vector of size - // options.num_buckets which are set to 1.0 (others are implicitly 0.0). - // - dense_features are values of a dense feature vector of size 0-2 - // (depending on the options) for the token + TokenFeatureExtractor(const TokenFeatureExtractorOptions& options, + const UniLib& unilib); + + // Extracts both the sparse (charactergram) and the dense features from a + // token. is_in_span is a bool indicator whether the token is a part of the + // selection span (true) or not (false). + // The sparse_features output is optional. Fails and returns false if + // dense_fatures in a nullptr. bool Extract(const Token& token, bool is_in_span, std::vector<int>* sparse_features, std::vector<float>* dense_features) const; + // Extracts the sparse (charactergram) features from the token. + std::vector<int> ExtractCharactergramFeatures(const Token& token) const; + + // Extracts the dense features from the token. is_in_span is a bool indicator + // whether the token is a part of the selection span (true) or not (false). + std::vector<float> ExtractDenseFeatures(const Token& token, + bool is_in_span) const; + int DenseFeaturesCount() const { int feature_count = options_.extract_case_feature + options_.extract_selection_mask_feature; -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT feature_count += regex_patterns_.size(); -#else - if (enable_all_caps_feature_) { - feature_count += 1; - } -#endif return feature_count; } @@ -97,9 +96,6 @@ class TokenFeatureExtractor { // Hashes given token to given number of buckets. int HashToken(StringPiece token) const; - // Extracts the charactergram features from the token. - std::vector<int> ExtractCharactergramFeatures(const Token& token) const; - // Extracts the charactergram features from the token in a non-unicode-aware // way. 
std::vector<int> ExtractCharactergramFeaturesAscii(const Token& token) const; @@ -110,13 +106,10 @@ class TokenFeatureExtractor { private: TokenFeatureExtractorOptions options_; -#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT - std::vector<std::unique_ptr<icu::RegexPattern>> regex_patterns_; -#else - bool enable_all_caps_feature_ = false; -#endif + std::vector<std::unique_ptr<UniLib::RegexPattern>> regex_patterns_; + const UniLib& unilib_; }; -} // namespace libtextclassifier +} // namespace libtextclassifier2 -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_ +#endif // LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_ diff --git a/smartselect/token-feature-extractor_test.cc b/token-feature-extractor_test.cc index 4b635fd..4b7e011 100644 --- a/smartselect/token-feature-extractor_test.cc +++ b/token-feature-extractor_test.cc @@ -14,18 +14,18 @@ * limitations under the License. */ -#include "smartselect/token-feature-extractor.h" +#include "token-feature-extractor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace { class TestingTokenFeatureExtractor : public TokenFeatureExtractor { public: - using TokenFeatureExtractor::TokenFeatureExtractor; using TokenFeatureExtractor::HashToken; + using TokenFeatureExtractor::TokenFeatureExtractor; }; TEST(TokenFeatureExtractorTest, ExtractAscii) { @@ -35,7 +35,8 @@ TEST(TokenFeatureExtractorTest, ExtractAscii) { options.extract_case_feature = true; options.unicode_aware_features = false; options.extract_selection_mask_feature = true; - TestingTokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -105,7 +106,8 @@ TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) { options.extract_case_feature = true; options.unicode_aware_features = false; options.extract_selection_mask_feature = true; - 
TestingTokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -134,7 +136,8 @@ TEST(TokenFeatureExtractorTest, ExtractUnicode) { options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - TestingTokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -204,7 +207,8 @@ TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) { options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - TestingTokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -227,6 +231,7 @@ TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) { EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); } +#ifdef LIBTEXTCLASSIFIER_TEST_ICU TEST(TokenFeatureExtractorTest, ICUCaseFeature) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; @@ -234,7 +239,8 @@ TEST(TokenFeatureExtractorTest, ICUCaseFeature) { options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = false; - TokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -260,6 +266,7 @@ TEST(TokenFeatureExtractorTest, ICUCaseFeature) { &dense_features); EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0})); } +#endif TEST(TokenFeatureExtractorTest, DigitRemapping) { TokenFeatureExtractorOptions options; @@ -267,7 +274,8 @@ 
TEST(TokenFeatureExtractorTest, DigitRemapping) { options.chargram_orders = std::vector<int>{1, 2}; options.remap_digits = true; options.unicode_aware_features = false; - TokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -291,7 +299,8 @@ TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) { options.chargram_orders = std::vector<int>{1, 2}; options.remap_digits = true; options.unicode_aware_features = true; - TokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -315,7 +324,8 @@ TEST(TokenFeatureExtractorTest, LowercaseAscii) { options.chargram_orders = std::vector<int>{1, 2}; options.lowercase_tokens = true; options.unicode_aware_features = false; - TokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -332,13 +342,15 @@ TEST(TokenFeatureExtractorTest, LowercaseAscii) { EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); } +#ifdef LIBTEXTCLASSIFIER_TEST_ICU TEST(TokenFeatureExtractorTest, LowercaseUnicode) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; options.lowercase_tokens = true; options.unicode_aware_features = true; - TokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -349,7 +361,9 @@ TEST(TokenFeatureExtractorTest, LowercaseUnicode) { &dense_features); EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); } +#endif +#ifdef LIBTEXTCLASSIFIER_TEST_ICU 
TEST(TokenFeatureExtractorTest, RegexFeatures) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; @@ -358,7 +372,8 @@ TEST(TokenFeatureExtractorTest, RegexFeatures) { options.unicode_aware_features = false; options.regexp_features.push_back("^[a-z]+$"); // all lower case. options.regexp_features.push_back("^[0-9]+$"); // all digits. - TokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -381,6 +396,7 @@ TEST(TokenFeatureExtractorTest, RegexFeatures) { &dense_features); EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0})); } +#endif TEST(TokenFeatureExtractorTest, ExtractTooLongWord) { TokenFeatureExtractorOptions options; @@ -389,7 +405,8 @@ TEST(TokenFeatureExtractorTest, ExtractTooLongWord) { options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - TestingTokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); // Test that this runs. ASAN should catch problems. 
std::vector<int> sparse_features; @@ -413,10 +430,12 @@ TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - TestingTokenFeatureExtractor extractor_unicode(options); + + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor_unicode(options, unilib); options.unicode_aware_features = false; - TestingTokenFeatureExtractor extractor_ascii(options); + TestingTokenFeatureExtractor extractor_ascii(options, unilib); for (const std::string& input : {"https://www.abcdefgh.com/in/xxxkkkvayio", @@ -447,7 +466,8 @@ TEST(TokenFeatureExtractorTest, ExtractForPadToken) { options.unicode_aware_features = false; options.extract_selection_mask_feature = true; - TestingTokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -473,7 +493,8 @@ TEST(TokenFeatureExtractorTest, ExtractFiltered) { options.allowed_chargrams.insert("!"); options.allowed_chargrams.insert("\xc4"); // UTF8 control character. - TestingTokenFeatureExtractor extractor(options); + CREATE_UNILIB_FOR_TESTING + TestingTokenFeatureExtractor extractor(options, unilib); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -540,4 +561,4 @@ TEST(TokenFeatureExtractorTest, ExtractFiltered) { } } // namespace -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/smartselect/tokenizer.cc b/tokenizer.cc index 2489a61..722a67b 100644 --- a/smartselect/tokenizer.cc +++ b/tokenizer.cc @@ -14,30 +14,36 @@ * limitations under the License. 
*/ -#include "smartselect/tokenizer.h" +#include "tokenizer.h" #include <algorithm> +#include "util/base/logging.h" #include "util/strings/utf8.h" -#include "util/utf8/unicodetext.h" -namespace libtextclassifier { +namespace libtextclassifier2 { Tokenizer::Tokenizer( - const std::vector<TokenizationCodepointRange>& codepoint_ranges) - : codepoint_ranges_(codepoint_ranges) { + const std::vector<const TokenizationCodepointRange*>& codepoint_ranges, + bool split_on_script_change) + : split_on_script_change_(split_on_script_change) { + for (const TokenizationCodepointRange* range : codepoint_ranges) { + codepoint_ranges_.emplace_back(range->UnPack()); + } + std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(), - [](const TokenizationCodepointRange& a, - const TokenizationCodepointRange& b) { - return a.start() < b.start(); + [](const std::unique_ptr<const TokenizationCodepointRangeT>& a, + const std::unique_ptr<const TokenizationCodepointRangeT>& b) { + return a->start < b->start; }); } -TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole( +const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange( int codepoint) const { auto it = std::lower_bound( codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint, - [](const TokenizationCodepointRange& range, int codepoint) { + [](const std::unique_ptr<const TokenizationCodepointRangeT>& range, + int codepoint) { // This function compares range with the codepoint for the purpose of // finding the first greater or equal range. Because of the use of // std::lower_bound it needs to return true when range < codepoint; @@ -47,43 +53,68 @@ TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole( // It might seem weird that the condition is range.end <= codepoint // here but when codepoint == range.end it means it's actually just // outside of the range, thus the range is less than the codepoint. 
- return range.end() <= codepoint; + return range->end <= codepoint; }); - if (it != codepoint_ranges_.end() && it->start() <= codepoint && - it->end() > codepoint) { - return it->role(); + if (it != codepoint_ranges_.end() && (*it)->start <= codepoint && + (*it)->end > codepoint) { + return it->get(); + } else { + return nullptr; + } +} + +void Tokenizer::GetScriptAndRole(char32 codepoint, + TokenizationCodepointRange_::Role* role, + int* script) const { + const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint); + if (range) { + *role = range->role; + *script = range->script_id; } else { - return TokenizationCodepointRange::DEFAULT_ROLE; + *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE; + *script = kUnknownScript; } } -std::vector<Token> Tokenizer::Tokenize(const std::string& utf8_text) const { - UnicodeText context_unicode = UTF8ToUnicodeText(utf8_text, /*do_copy=*/false); +std::vector<Token> Tokenizer::Tokenize(const std::string& text) const { + UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false); + return Tokenize(text_unicode); +} +std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const { std::vector<Token> result; Token new_token("", 0, 0); int codepoint_index = 0; - for (auto it = context_unicode.begin(); it != context_unicode.end(); + + int last_script = kInvalidScript; + for (auto it = text_unicode.begin(); it != text_unicode.end(); ++it, ++codepoint_index) { - TokenizationCodepointRange::Role role = FindTokenizationRole(*it); - if (role & TokenizationCodepointRange::SPLIT_BEFORE) { + TokenizationCodepointRange_::Role role; + int script; + GetScriptAndRole(*it, &role, &script); + + if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE || + (split_on_script_change_ && last_script != kInvalidScript && + last_script != script)) { if (!new_token.value.empty()) { result.push_back(new_token); } new_token = Token("", codepoint_index, codepoint_index); } - if (!(role & 
TokenizationCodepointRange::DISCARD_CODEPOINT)) { + if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) { new_token.value += std::string( it.utf8_data(), it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data())); ++new_token.end; } - if (role & TokenizationCodepointRange::SPLIT_AFTER) { + if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) { if (!new_token.value.empty()) { result.push_back(new_token); } new_token = Token("", codepoint_index + 1, codepoint_index + 1); } + + last_script = script; } if (!new_token.value.empty()) { result.push_back(new_token); @@ -92,4 +123,4 @@ std::vector<Token> Tokenizer::Tokenize(const std::string& utf8_text) const { return result; } -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/tokenizer.h b/tokenizer.h new file mode 100644 index 0000000..2524e12 --- /dev/null +++ b/tokenizer.h @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_TOKENIZER_H_ +#define LIBTEXTCLASSIFIER_TOKENIZER_H_ + +#include <string> +#include <vector> + +#include "model_generated.h" +#include "types.h" +#include "util/base/integral_types.h" +#include "util/utf8/unicodetext.h" + +namespace libtextclassifier2 { + +const int kInvalidScript = -1; +const int kUnknownScript = -2; + +// Tokenizer splits the input string into a sequence of tokens, according to the +// configuration. 
+class Tokenizer { + public: + explicit Tokenizer( + const std::vector<const TokenizationCodepointRange*>& codepoint_ranges, + bool split_on_script_change); + + // Tokenizes the input string using the selected tokenization method. + std::vector<Token> Tokenize(const std::string& text) const; + + // Same as above but takes UnicodeText. + std::vector<Token> Tokenize(const UnicodeText& text_unicode) const; + + protected: + // Finds the tokenization codepoint range config for given codepoint. + // Internally uses binary search so should be O(log(# of codepoint_ranges)). + const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const; + + // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE + // and kUnknownScript are assigned. + void GetScriptAndRole(char32 codepoint, + TokenizationCodepointRange_::Role* role, + int* script) const; + + private: + // Codepoint ranges that determine how different codepoints are tokenized. + // The ranges must not overlap. + std::vector<std::unique_ptr<const TokenizationCodepointRangeT>> + codepoint_ranges_; + + // If true, tokens will be additionally split when the codepoint's script_id + // changes. + bool split_on_script_change_; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_TOKENIZER_H_ diff --git a/tokenizer_test.cc b/tokenizer_test.cc new file mode 100644 index 0000000..65072f3 --- /dev/null +++ b/tokenizer_test.cc @@ -0,0 +1,334 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tokenizer.h" + +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +using testing::ElementsAreArray; + +class TestingTokenizer : public Tokenizer { + public: + explicit TestingTokenizer( + const std::vector<const TokenizationCodepointRange*>& + codepoint_range_configs, + bool split_on_script_change) + : Tokenizer(codepoint_range_configs, split_on_script_change) {} + + using Tokenizer::FindTokenizationRange; +}; + +class TestingTokenizerProxy { + public: + explicit TestingTokenizerProxy( + const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs, + bool split_on_script_change) { + int num_configs = codepoint_range_configs.size(); + std::vector<const TokenizationCodepointRange*> configs_fb; + buffers_.reserve(num_configs); + for (int i = 0; i < num_configs; i++) { + flatbuffers::FlatBufferBuilder builder; + builder.Finish(CreateTokenizationCodepointRange( + builder, &codepoint_range_configs[i])); + buffers_.push_back(builder.Release()); + configs_fb.push_back( + flatbuffers::GetRoot<TokenizationCodepointRange>(buffers_[i].data())); + } + tokenizer_ = std::unique_ptr<TestingTokenizer>( + new TestingTokenizer(configs_fb, split_on_script_change)); + } + + TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const { + const TokenizationCodepointRangeT* range = + tokenizer_->FindTokenizationRange(c); + if (range != nullptr) { + return range->role; + } else { + return TokenizationCodepointRange_::Role_DEFAULT_ROLE; + } + } + + std::vector<Token> Tokenize(const std::string& utf8_text) const { + return tokenizer_->Tokenize(utf8_text); + } + + private: + std::vector<flatbuffers::DetachedBuffer> buffers_; + std::unique_ptr<TestingTokenizer> tokenizer_; +}; + +TEST(TokenizerTest, FindTokenizationRange) { + std::vector<TokenizationCodepointRangeT> configs; + 
TokenizationCodepointRangeT* config; + + configs.emplace_back(); + config = &configs.back(); + config->start = 0; + config->end = 10; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + + configs.emplace_back(); + config = &configs.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + configs.emplace_back(); + config = &configs.back(); + config->start = 1234; + config->end = 12345; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + + TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false); + + // Test hits to the first group. + EXPECT_EQ(tokenizer.TestFindTokenizationRole(0), + TokenizationCodepointRange_::Role_TOKEN_SEPARATOR); + EXPECT_EQ(tokenizer.TestFindTokenizationRole(5), + TokenizationCodepointRange_::Role_TOKEN_SEPARATOR); + EXPECT_EQ(tokenizer.TestFindTokenizationRole(10), + TokenizationCodepointRange_::Role_DEFAULT_ROLE); + + // Test a hit to the second group. + EXPECT_EQ(tokenizer.TestFindTokenizationRole(31), + TokenizationCodepointRange_::Role_DEFAULT_ROLE); + EXPECT_EQ(tokenizer.TestFindTokenizationRole(32), + TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR); + EXPECT_EQ(tokenizer.TestFindTokenizationRole(33), + TokenizationCodepointRange_::Role_DEFAULT_ROLE); + + // Test hits to the third group. + EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233), + TokenizationCodepointRange_::Role_DEFAULT_ROLE); + EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234), + TokenizationCodepointRange_::Role_TOKEN_SEPARATOR); + EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344), + TokenizationCodepointRange_::Role_TOKEN_SEPARATOR); + EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345), + TokenizationCodepointRange_::Role_DEFAULT_ROLE); + + // Test a hit outside. 
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(99), + TokenizationCodepointRange_::Role_DEFAULT_ROLE); +} + +TEST(TokenizerTest, TokenizeOnSpace) { + std::vector<TokenizationCodepointRangeT> configs; + TokenizationCodepointRangeT* config; + + configs.emplace_back(); + config = &configs.back(); + // Space character. + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false); + std::vector<Token> tokens = tokenizer.Tokenize("Hello world!"); + + EXPECT_THAT(tokens, + ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)})); +} + +TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) { + std::vector<TokenizationCodepointRangeT> configs; + TokenizationCodepointRangeT* config; + + // Latin. + configs.emplace_back(); + config = &configs.back(); + config->start = 0; + config->end = 32; + config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE; + config->script_id = 1; + configs.emplace_back(); + config = &configs.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + config->script_id = 1; + configs.emplace_back(); + config = &configs.back(); + config->start = 33; + config->end = 0x77F + 1; + config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE; + config->script_id = 1; + + TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/true); + EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"), + std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6), + Token("전화", 7, 10), Token("(123)", 10, 15), + Token("456-789", 16, 23), + Token("웹사이트", 23, 28)})); +} // namespace + +TEST(TokenizerTest, TokenizeComplex) { + std::vector<TokenizationCodepointRangeT> configs; + TokenizationCodepointRangeT* config; + + // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt + // Latin - cyrilic. 
+ // 0000..007F; Basic Latin + // 0080..00FF; Latin-1 Supplement + // 0100..017F; Latin Extended-A + // 0180..024F; Latin Extended-B + // 0250..02AF; IPA Extensions + // 02B0..02FF; Spacing Modifier Letters + // 0300..036F; Combining Diacritical Marks + // 0370..03FF; Greek and Coptic + // 0400..04FF; Cyrillic + // 0500..052F; Cyrillic Supplement + // 0530..058F; Armenian + // 0590..05FF; Hebrew + // 0600..06FF; Arabic + // 0700..074F; Syriac + // 0750..077F; Arabic Supplement + configs.emplace_back(); + config = &configs.back(); + config->start = 0; + config->end = 32; + config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE; + configs.emplace_back(); + config = &configs.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 33; + config->end = 0x77F + 1; + config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE; + + // CJK + // 2E80..2EFF; CJK Radicals Supplement + // 3000..303F; CJK Symbols and Punctuation + // 3040..309F; Hiragana + // 30A0..30FF; Katakana + // 3100..312F; Bopomofo + // 3130..318F; Hangul Compatibility Jamo + // 3190..319F; Kanbun + // 31A0..31BF; Bopomofo Extended + // 31C0..31EF; CJK Strokes + // 31F0..31FF; Katakana Phonetic Extensions + // 3200..32FF; Enclosed CJK Letters and Months + // 3300..33FF; CJK Compatibility + // 3400..4DBF; CJK Unified Ideographs Extension A + // 4DC0..4DFF; Yijing Hexagram Symbols + // 4E00..9FFF; CJK Unified Ideographs + // A000..A48F; Yi Syllables + // A490..A4CF; Yi Radicals + // A4D0..A4FF; Lisu + // A500..A63F; Vai + // F900..FAFF; CJK Compatibility Ideographs + // FE30..FE4F; CJK Compatibility Forms + // 20000..2A6DF; CJK Unified Ideographs Extension B + // 2A700..2B73F; CJK Unified Ideographs Extension C + // 2B740..2B81F; CJK Unified Ideographs Extension D + // 2B820..2CEAF; CJK Unified Ideographs Extension E + // 2CEB0..2EBEF; CJK Unified Ideographs 
Extension F + // 2F800..2FA1F; CJK Compatibility Ideographs Supplement + configs.emplace_back(); + config = &configs.back(); + config->start = 0x2E80; + config->end = 0x2EFF + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0x3000; + config->end = 0xA63F + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0xF900; + config->end = 0xFAFF + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0xFE30; + config->end = 0xFE4F + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0x20000; + config->end = 0x2A6DF + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0x2A700; + config->end = 0x2B73F + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0x2B740; + config->end = 0x2B81F + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0x2B820; + config->end = 0x2CEAF + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0x2CEB0; + config->end = 0x2EBEF + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + configs.emplace_back(); + config = &configs.back(); + config->start = 0x2F800; + config->end = 0x2FA1F + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + + // Thai. 
+ // 0E00..0E7F; Thai + configs.emplace_back(); + config = &configs.back(); + config->start = 0x0E00; + config->end = 0x0E7F + 1; + config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR; + + TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false); + std::vector<Token> tokens; + + tokens = tokenizer.Tokenize( + "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。"); + EXPECT_EQ(tokens.size(), 30); + + tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ"); + // clang-format off + EXPECT_THAT( + tokens, + ElementsAreArray({Token("問", 0, 1), + Token("少", 1, 2), + Token("目", 2, 3), + Token("hello", 4, 9), + Token("木", 10, 11), + Token("輸", 11, 12), + Token("ย", 12, 13), + Token("า", 13, 14), + Token("ม", 14, 15), + Token("き", 15, 16), + Token("ゃ", 16, 17)})); + // clang-format on +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/types-test-util.h b/types-test-util.h new file mode 100644 index 0000000..1679e7c --- /dev/null +++ b/types-test-util.h @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_ +#define LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_ + +#include <ostream> + +#include "types.h" +#include "util/base/logging.h" + +namespace libtextclassifier2 { + +inline std::ostream& operator<<(std::ostream& stream, const Token& value) { + logging::LoggingStringStream tmp_stream; + tmp_stream << value; + return stream << tmp_stream.message; +} + +inline std::ostream& operator<<(std::ostream& stream, + const AnnotatedSpan& value) { + logging::LoggingStringStream tmp_stream; + tmp_stream << value; + return stream << tmp_stream.message; +} + +inline std::ostream& operator<<(std::ostream& stream, + const DatetimeParseResultSpan& value) { + logging::LoggingStringStream tmp_stream; + tmp_stream << value; + return stream << tmp_stream.message; +} + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_ @@ -0,0 +1,396 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_TYPES_H_ +#define LIBTEXTCLASSIFIER_TYPES_H_ + +#include <algorithm> +#include <cmath> +#include <functional> +#include <set> +#include <string> +#include <utility> +#include <vector> +#include "util/base/integral_types.h" + +#include "util/base/logging.h" + +namespace libtextclassifier2 { + +constexpr int kInvalidIndex = -1; + +// Index for a 0-based array of tokens. 
+using TokenIndex = int; + +// Index for a 0-based array of codepoints. +using CodepointIndex = int; + +// Marks a span in a sequence of codepoints. The first element is the index of +// the first codepoint of the span, and the second element is the index of the +// codepoint one past the end of the span. +// TODO(b/71982294): Make it a struct. +using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>; + +inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) { + return a.first < b.second && b.first < a.second; +} + +inline bool ValidNonEmptySpan(const CodepointSpan& span) { + return span.first < span.second && span.first >= 0 && span.second >= 0; +} + +template <typename T> +bool DoesCandidateConflict( + const int considered_candidate, const std::vector<T>& candidates, + const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) { + if (chosen_indices_set.empty()) { + return false; + } + + auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate); + // Check conflict on the right. + if (conflicting_it != chosen_indices_set.end() && + SpansOverlap(candidates[considered_candidate].span, + candidates[*conflicting_it].span)) { + return true; + } + + // Check conflict on the left. + // If we can't go more left, there can't be a conflict: + if (conflicting_it == chosen_indices_set.begin()) { + return false; + } + // Otherwise move one span left and insert if it doesn't overlap with the + // candidate. + --conflicting_it; + if (!SpansOverlap(candidates[considered_candidate].span, + candidates[*conflicting_it].span)) { + return false; + } + + return true; +} + +// Marks a span in a sequence of tokens. The first element is the index of the +// first token in the span, and the second element is the index of the token one +// past the end of the span. +// TODO(b/71982294): Make it a struct. +using TokenSpan = std::pair<TokenIndex, TokenIndex>; + +// Returns the size of the token span. Assumes that the span is valid. 
+inline int TokenSpanSize(const TokenSpan& token_span) { + return token_span.second - token_span.first; +} + +// Returns a token span consisting of one token. +inline TokenSpan SingleTokenSpan(int token_index) { + return {token_index, token_index + 1}; +} + +// Returns an intersection of two token spans. Assumes that both spans are valid +// and overlapping. +inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1, + const TokenSpan& token_span2) { + return {std::max(token_span1.first, token_span2.first), + std::min(token_span1.second, token_span2.second)}; +} + +// Returns and expanded token span by adding a certain number of tokens on its +// left and on its right. +inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span, + int num_tokens_left, int num_tokens_right) { + return {token_span.first - num_tokens_left, + token_span.second + num_tokens_right}; +} + +// Token holds a token, its position in the original string and whether it was +// part of the input span. +struct Token { + std::string value; + CodepointIndex start; + CodepointIndex end; + + // Whether the token is a padding token. + bool is_padding; + + // Default constructor constructs the padding-token. + Token() + : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {} + + Token(const std::string& arg_value, CodepointIndex arg_start, + CodepointIndex arg_end) + : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {} + + bool operator==(const Token& other) const { + return value == other.value && start == other.start && end == other.end && + is_padding == other.is_padding; + } + + bool IsContainedInSpan(CodepointSpan span) const { + return start >= span.first && end <= span.second; + } +}; + +// Pretty-printing function for Token. 
+inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, const Token& token) { + if (!token.is_padding) { + return stream << "Token(\"" << token.value << "\", " << token.start << ", " + << token.end << ")"; + } else { + return stream << "Token()"; + } +} + +enum DatetimeGranularity { + GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this + // structure being uninitialized. + GRANULARITY_YEAR = 0, + GRANULARITY_MONTH = 1, + GRANULARITY_WEEK = 2, + GRANULARITY_DAY = 3, + GRANULARITY_HOUR = 4, + GRANULARITY_MINUTE = 5, + GRANULARITY_SECOND = 6 +}; + +struct DatetimeParseResult { + // The absolute time in milliseconds since the epoch in UTC. This is derived + // from the reference time and the fields specified in the text - so it may + // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm) + int64 time_ms_utc; + + // The precision of the estimate then in to calculating the milliseconds + DatetimeGranularity granularity; + + DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {} + + DatetimeParseResult(int64 arg_time_ms_utc, + DatetimeGranularity arg_granularity) + : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {} + + bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; } + + bool operator==(const DatetimeParseResult& other) const { + return granularity == other.granularity && time_ms_utc == other.time_ms_utc; + } +}; + +const float kFloatCompareEpsilon = 1e-5; + +struct DatetimeParseResultSpan { + CodepointSpan span; + DatetimeParseResult data; + float target_classification_score; + float priority_score; + + bool operator==(const DatetimeParseResultSpan& other) const { + return span == other.span && data.granularity == other.data.granularity && + data.time_ms_utc == other.data.time_ms_utc && + std::abs(target_classification_score - + other.target_classification_score) < kFloatCompareEpsilon && + std::abs(priority_score - other.priority_score) < 
+ kFloatCompareEpsilon; + } +}; + +// Pretty-printing function for DatetimeParseResultSpan. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, + const DatetimeParseResultSpan& value) { + return stream << "DatetimeParseResultSpan({" << value.span.first << ", " + << value.span.second << "}, {/*time_ms_utc=*/ " + << value.data.time_ms_utc << ", /*granularity=*/ " + << value.data.granularity << "})"; +} + +struct ClassificationResult { + std::string collection; + float score; + DatetimeParseResult datetime_parse_result; + + // Internal score used for conflict resolution. + float priority_score; + + explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {} + + ClassificationResult(const std::string& arg_collection, float arg_score) + : collection(arg_collection), + score(arg_score), + priority_score(arg_score) {} + + ClassificationResult(const std::string& arg_collection, float arg_score, + float arg_priority_score) + : collection(arg_collection), + score(arg_score), + priority_score(arg_priority_score) {} +}; + +// Pretty-printing function for ClassificationResult. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, const ClassificationResult& result) { + return stream << "ClassificationResult(" << result.collection << ", " + << result.score << ")"; +} + +// Pretty-printing function for std::vector<ClassificationResult>. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, + const std::vector<ClassificationResult>& results) { + stream = stream << "{\n"; + for (const ClassificationResult& result : results) { + stream = stream << " " << result << "\n"; + } + stream = stream << "}"; + return stream; +} + +// Represents a result of Annotate call. +struct AnnotatedSpan { + // Unicode codepoint indices in the input string. + CodepointSpan span = {kInvalidIndex, kInvalidIndex}; + + // Classification result for the span. 
+ std::vector<ClassificationResult> classification; +}; + +// Pretty-printing function for AnnotatedSpan. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, const AnnotatedSpan& span) { + std::string best_class; + float best_score = -1; + if (!span.classification.empty()) { + best_class = span.classification[0].collection; + best_score = span.classification[0].score; + } + return stream << "Span(" << span.span.first << ", " << span.span.second + << ", " << best_class << ", " << best_score << ")"; +} + +// StringPiece analogue for std::vector<T>. +template <class T> +class VectorSpan { + public: + VectorSpan() : begin_(), end_() {} + VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit) + : begin_(v.begin()), end_(v.end()) {} + VectorSpan(typename std::vector<T>::const_iterator begin, + typename std::vector<T>::const_iterator end) + : begin_(begin), end_(end) {} + + const T& operator[](typename std::vector<T>::size_type i) const { + return *(begin_ + i); + } + + int size() const { return end_ - begin_; } + typename std::vector<T>::const_iterator begin() const { return begin_; } + typename std::vector<T>::const_iterator end() const { return end_; } + const float* data() const { return &(*begin_); } + + private: + typename std::vector<T>::const_iterator begin_; + typename std::vector<T>::const_iterator end_; +}; + +struct DateParseData { + enum Relation { + NEXT = 1, + NEXT_OR_SAME = 2, + LAST = 3, + NOW = 4, + TOMORROW = 5, + YESTERDAY = 6, + PAST = 7, + FUTURE = 8 + }; + + enum RelationType { + MONDAY = 1, + TUESDAY = 2, + WEDNESDAY = 3, + THURSDAY = 4, + FRIDAY = 5, + SATURDAY = 6, + SUNDAY = 7, + DAY = 8, + WEEK = 9, + MONTH = 10, + YEAR = 11 + }; + + enum Fields { + YEAR_FIELD = 1 << 0, + MONTH_FIELD = 1 << 1, + DAY_FIELD = 1 << 2, + HOUR_FIELD = 1 << 3, + MINUTE_FIELD = 1 << 4, + SECOND_FIELD = 1 << 5, + AMPM_FIELD = 1 << 6, + ZONE_OFFSET_FIELD = 1 << 7, + DST_OFFSET_FIELD = 1 << 8, + RELATION_FIELD = 1 << 9, + 
RELATION_TYPE_FIELD = 1 << 10, + RELATION_DISTANCE_FIELD = 1 << 11 + }; + + enum AMPM { AM = 0, PM = 1 }; + + enum TimeUnit { + DAYS = 1, + WEEKS = 2, + MONTHS = 3, + HOURS = 4, + MINUTES = 5, + SECONDS = 6, + YEARS = 7 + }; + + // Bit mask of fields which have been set on the struct + int field_set_mask; + + // Fields describing absolute date fields. + // Year of the date seen in the text match. + int year; + // Month of the year starting with January = 1. + int month; + // Day of the month starting with 1. + int day_of_month; + // Hour of the day with a range of 0-23, + // values less than 12 need the AMPM field below or heuristics + // to definitively determine the time. + int hour; + // Hour of the day with a range of 0-59. + int minute; + // Hour of the day with a range of 0-59. + int second; + // 0 == AM, 1 == PM + int ampm; + // Number of hours offset from UTC this date time is in. + int zone_offset; + // Number of hours offest for DST + int dst_offset; + + // The permutation from now that was made to find the date time. + Relation relation; + // The unit of measure of the change to the date time. + RelationType relation_type; + // The number of units of change that were made. + int relation_distance; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_TYPES_H_ diff --git a/util/base/casts.h b/util/base/casts.h index 805ee89..a1d2056 100644 --- a/util/base/casts.h +++ b/util/base/casts.h @@ -19,7 +19,7 @@ #include <string.h> // for memcpy -namespace libtextclassifier { +namespace libtextclassifier2 { // bit_cast<Dest, Source> is a template function that implements the equivalent // of "*reinterpret_cast<Dest*>(&source)". 
We need this in very low-level @@ -87,6 +87,6 @@ inline Dest bit_cast(const Source &source) { return dest; } -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_ diff --git a/util/base/config.h b/util/base/config.h index e6c19a4..8844b14 100644 --- a/util/base/config.h +++ b/util/base/config.h @@ -19,7 +19,7 @@ #ifndef LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_ #define LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_ -namespace libtextclassifier { +namespace libtextclassifier2 { // Define LANG_CXX11 to 1 if current compiler supports C++11. // @@ -38,6 +38,6 @@ namespace libtextclassifier { #define LANG_CXX11 1 #endif -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_ diff --git a/util/base/endian.h b/util/base/endian.h index 75f8bf7..2dfbfd6 100644 --- a/util/base/endian.h +++ b/util/base/endian.h @@ -19,7 +19,7 @@ #include "util/base/integral_types.h" -namespace libtextclassifier { +namespace libtextclassifier2 { #if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \ defined(__ANDROID__) @@ -133,6 +133,6 @@ class LittleEndian { #endif /* ENDIAN */ }; -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_ diff --git a/util/base/integral_types.h b/util/base/integral_types.h index 0322d33..f82c9cd 100644 --- a/util/base/integral_types.h +++ b/util/base/integral_types.h @@ -21,7 +21,7 @@ #include "util/base/config.h" -namespace libtextclassifier { +namespace libtextclassifier2 { typedef unsigned int uint32; typedef unsigned long long uint64; @@ -56,6 +56,6 @@ static_assert(sizeof(char32) == 4, "wrong size"); static_assert(sizeof(int64) == 8, "wrong size"); #endif // LANG_CXX11 -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_ diff --git a/util/base/logging.cc b/util/base/logging.cc index 9de35ca..919bb36 
100644 --- a/util/base/logging.cc +++ b/util/base/logging.cc @@ -22,7 +22,7 @@ #include "util/base/logging_raw.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace logging { namespace { @@ -57,12 +57,11 @@ LogMessage::LogMessage(LogSeverity severity, const char *file_name, } LogMessage::~LogMessage() { - const std::string message = stream_.str(); - LowLevelLogging(severity_, /* tag = */ "txtClsf", message); + LowLevelLogging(severity_, /* tag = */ "txtClsf", stream_.message); if (severity_ == FATAL) { exit(1); } } } // namespace logging -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/util/base/logging.h b/util/base/logging.h index dba0ed4..4391d46 100644 --- a/util/base/logging.h +++ b/util/base/logging.h @@ -18,32 +18,52 @@ #define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_ #include <cassert> -#include <sstream> #include <string> #include "util/base/logging_levels.h" #include "util/base/port.h" -// TC_STRIP -namespace libtextclassifier { -// string class that can't be instantiated. Makes sure that the code does not -// compile when non std::string is used. -// -// NOTE: defined here because most files directly or transitively include this -// file. Asking people to include a special header just to make sure they don't -// use the unqualified string doesn't work: as that header doesn't produce any -// immediate benefit, one can easily forget about it. -class string { - public: - // Makes the class non-instantiable. - virtual ~string() = 0; -}; -} // namespace libtextclassifier -// TC_END_STRIP -namespace libtextclassifier { +namespace libtextclassifier2 { namespace logging { +// A tiny code footprint string stream for assembling log messages. +struct LoggingStringStream { + LoggingStringStream() {} + LoggingStringStream &stream() { return *this; } + // Needed for invocation in TC_CHECK macro. 
+ explicit operator bool() const { return true; } + + std::string message; +}; + +template <typename T> +inline LoggingStringStream &operator<<(LoggingStringStream &stream, + const T &entry) { + stream.message.append(std::to_string(entry)); + return stream; +} + +inline LoggingStringStream &operator<<(LoggingStringStream &stream, + const char *message) { + stream.message.append(message); + return stream; +} + +#if defined(HAS_GLOBAL_STRING) +inline LoggingStringStream &operator<<(LoggingStringStream &stream, + const ::string &message) { + stream.message.append(message); + return stream; +} +#endif + +inline LoggingStringStream &operator<<(LoggingStringStream &stream, + const std::string &message) { + stream.message.append(message); + return stream; +} + // The class that does all the work behind our TC_LOG(severity) macros. Each // TC_LOG(severity) << obj1 << obj2 << ...; logging statement creates a // LogMessage temporary object containing a stringstream. Each operator<< adds @@ -61,19 +81,34 @@ class LogMessage { ~LogMessage() TC_ATTRIBUTE_NOINLINE; // Returns the stream associated with the logger object. - std::stringstream &stream() { return stream_; } + LoggingStringStream &stream() { return stream_; } private: const LogSeverity severity_; // Stream that "prints" all info into a string (not to a file). We construct // here the entire logging message and next print it in one operation. - std::stringstream stream_; + LoggingStringStream stream_; +}; + +// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing +// anything. 
+class NullStream { + public: + NullStream() {} + NullStream &stream() { return *this; } }; +template <typename T> +inline NullStream &operator<<(NullStream &str, const T &) { + return str; +} -#define TC_LOG(severity) \ - ::libtextclassifier::logging::LogMessage( \ - ::libtextclassifier::logging::severity, __FILE__, __LINE__) \ +} // namespace logging +} // namespace libtextclassifier2 + +#define TC_LOG(severity) \ + ::libtextclassifier2::logging::LogMessage( \ + ::libtextclassifier2::logging::severity, __FILE__, __LINE__) \ .stream() // If condition x is true, does nothing. Otherwise, crashes the program (liek @@ -92,19 +127,7 @@ class LogMessage { #define TC_CHECK_GE(x, y) TC_CHECK((x) >= (y)) #define TC_CHECK_NE(x, y) TC_CHECK((x) != (y)) -// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing -// anything. -class NullStream { - public: - NullStream() {} - NullStream &stream() { return *this; } -}; -template <typename T> -inline NullStream &operator<<(NullStream &str, const T &) { - return str; -} - -#define TC_NULLSTREAM ::libtextclassifier::logging::NullStream().stream() +#define TC_NULLSTREAM ::libtextclassifier2::logging::NullStream().stream() // Debug checks: a TC_DCHECK<suffix> macro should behave like TC_CHECK<suffix> // in debug mode an don't check / don't print anything in non-debug mode. 
@@ -133,15 +156,12 @@ inline NullStream &operator<<(NullStream &str, const T &) { #endif // NDEBUG #ifdef LIBTEXTCLASSIFIER_VLOG -#define TC_VLOG(severity) \ - ::libtextclassifier::logging::LogMessage(::libtextclassifier::logging::INFO, \ - __FILE__, __LINE__) \ +#define TC_VLOG(severity) \ + ::libtextclassifier2::logging::LogMessage( \ + ::libtextclassifier2::logging::INFO, __FILE__, __LINE__) \ .stream() #else #define TC_VLOG(severity) TC_NULLSTREAM #endif -} // namespace logging -} // namespace libtextclassifier - #endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_ diff --git a/util/base/logging_levels.h b/util/base/logging_levels.h index d16f96a..17c882f 100644 --- a/util/base/logging_levels.h +++ b/util/base/logging_levels.h @@ -17,7 +17,7 @@ #ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_ #define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_ -namespace libtextclassifier { +namespace libtextclassifier2 { namespace logging { enum LogSeverity { @@ -28,6 +28,6 @@ enum LogSeverity { }; } // namespace logging -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_ diff --git a/util/base/logging_raw.cc b/util/base/logging_raw.cc index 8e0eb1b..6d97852 100644 --- a/util/base/logging_raw.cc +++ b/util/base/logging_raw.cc @@ -26,7 +26,7 @@ // Compiled as part of Android. #include <android/log.h> -namespace libtextclassifier { +namespace libtextclassifier2 { namespace logging { namespace { @@ -60,12 +60,12 @@ void LowLevelLogging(LogSeverity severity, const std::string& tag, } } // namespace logging -} // namespace libtextclassifier +} // namespace libtextclassifier2 #else // if defined(__ANDROID__) // Not on Android: implement LowLevelLogging to print to stderr (see below). 
-namespace libtextclassifier { +namespace libtextclassifier2 { namespace logging { namespace { @@ -94,6 +94,6 @@ void LowLevelLogging(LogSeverity severity, const std::string &tag, } } // namespace logging -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // if defined(__ANDROID__) diff --git a/util/base/logging_raw.h b/util/base/logging_raw.h index 40c2497..e6265c7 100644 --- a/util/base/logging_raw.h +++ b/util/base/logging_raw.h @@ -21,7 +21,7 @@ #include "util/base/logging_levels.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace logging { // Low-level logging primitive. Logs a message, with the indicated log @@ -31,6 +31,6 @@ void LowLevelLogging(LogSeverity severity, const std::string &tag, const std::string &message); } // namespace logging -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_ diff --git a/util/base/macros.h b/util/base/macros.h index aec3a8a..a021ab9 100644 --- a/util/base/macros.h +++ b/util/base/macros.h @@ -19,7 +19,7 @@ #include "util/base/config.h" -namespace libtextclassifier { +namespace libtextclassifier2 { #if LANG_CXX11 #define TC_DISALLOW_COPY_AND_ASSIGN(TypeName) \ @@ -68,16 +68,18 @@ namespace libtextclassifier { // // In either case this macro has no effect on runtime behavior and performance // of code. 
-#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning) +#if defined(__clang__) && defined(__has_warning) #if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") -#define TC_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT +#define TC_FALLTHROUGH_INTENDED [[clang::fallthrough]] #endif +#elif defined(__GNUC__) && __GNUC__ >= 7 +#define TC_FALLTHROUGH_INTENDED [[gnu::fallthrough]] #endif #ifndef TC_FALLTHROUGH_INTENDED #define TC_FALLTHROUGH_INTENDED do { } while (0) #endif -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_ diff --git a/util/base/port.h b/util/base/port.h index 394aaab..90a2bce 100644 --- a/util/base/port.h +++ b/util/base/port.h @@ -19,7 +19,7 @@ #ifndef LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_ #define LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_ -namespace libtextclassifier { +namespace libtextclassifier2 { #if defined(__GNUC__) && \ (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1)) @@ -40,6 +40,6 @@ namespace libtextclassifier { #define TC_ATTRIBUTE_NOINLINE #endif -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_ diff --git a/util/calendar/calendar-icu.cc b/util/calendar/calendar-icu.cc new file mode 100644 index 0000000..34ea22d --- /dev/null +++ b/util/calendar/calendar-icu.cc @@ -0,0 +1,436 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/calendar/calendar-icu.h" + +#include <memory> + +#include "util/base/macros.h" +#include "unicode/gregocal.h" +#include "unicode/timezone.h" +#include "unicode/ucal.h" + +namespace libtextclassifier2 { +namespace { +int MapToDayOfWeekOrDefault(int relation_type, int default_value) { + switch (relation_type) { + case DateParseData::MONDAY: + return UCalendarDaysOfWeek::UCAL_MONDAY; + case DateParseData::TUESDAY: + return UCalendarDaysOfWeek::UCAL_TUESDAY; + case DateParseData::WEDNESDAY: + return UCalendarDaysOfWeek::UCAL_WEDNESDAY; + case DateParseData::THURSDAY: + return UCalendarDaysOfWeek::UCAL_THURSDAY; + case DateParseData::FRIDAY: + return UCalendarDaysOfWeek::UCAL_FRIDAY; + case DateParseData::SATURDAY: + return UCalendarDaysOfWeek::UCAL_SATURDAY; + case DateParseData::SUNDAY: + return UCalendarDaysOfWeek::UCAL_SUNDAY; + default: + return default_value; + } +} + +bool DispatchToRecedeOrToLastDayOfWeek(icu::Calendar* date, int relation_type, + int distance) { + UErrorCode status = U_ZERO_ERROR; + switch (relation_type) { + case DateParseData::MONDAY: + case DateParseData::TUESDAY: + case DateParseData::WEDNESDAY: + case DateParseData::THURSDAY: + case DateParseData::FRIDAY: + case DateParseData::SATURDAY: + case DateParseData::SUNDAY: + for (int i = 0; i < distance; i++) { + do { + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error day of week"; + return false; + } + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a day"; + return false; + } + } while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) != + MapToDayOfWeekOrDefault(relation_type, 1)); + } + return true; + case DateParseData::DAY: + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -1 * distance, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a day"; + return false; 
+ } + + return true; + case DateParseData::WEEK: + date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1); + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -7 * (distance - 1), + status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a week"; + return false; + } + + return true; + case DateParseData::MONTH: + date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); + date->add(UCalendarDateFields::UCAL_MONTH, -1 * (distance - 1), status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a month"; + return false; + } + return true; + case DateParseData::YEAR: + date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1); + date->add(UCalendarDateFields::UCAL_YEAR, -1 * (distance - 1), status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a year"; + + return true; + default: + return false; + } + return false; + } +} + +bool DispatchToAdvancerOrToNextOrSameDayOfWeek(icu::Calendar* date, + int relation_type) { + UErrorCode status = U_ZERO_ERROR; + switch (relation_type) { + case DateParseData::MONDAY: + case DateParseData::TUESDAY: + case DateParseData::WEDNESDAY: + case DateParseData::THURSDAY: + case DateParseData::FRIDAY: + case DateParseData::SATURDAY: + case DateParseData::SUNDAY: + while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) != + MapToDayOfWeekOrDefault(relation_type, 1)) { + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error day of week"; + return false; + } + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a day"; + return false; + } + } + return true; + case DateParseData::DAY: + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a day"; + return false; + } + + return true; + case DateParseData::WEEK: + date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1); + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 7, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << 
"error adding a week"; + return false; + } + + return true; + case DateParseData::MONTH: + date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); + date->add(UCalendarDateFields::UCAL_MONTH, 1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a month"; + return false; + } + return true; + case DateParseData::YEAR: + date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1); + date->add(UCalendarDateFields::UCAL_YEAR, 1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a year"; + + return true; + default: + return false; + } + return false; + } +} + +bool DispatchToAdvancerOrToNextDayOfWeek(icu::Calendar* date, int relation_type, + int distance) { + UErrorCode status = U_ZERO_ERROR; + switch (relation_type) { + case DateParseData::MONDAY: + case DateParseData::TUESDAY: + case DateParseData::WEDNESDAY: + case DateParseData::THURSDAY: + case DateParseData::FRIDAY: + case DateParseData::SATURDAY: + case DateParseData::SUNDAY: + for (int i = 0; i < distance; i++) { + do { + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error day of week"; + return false; + } + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a day"; + return false; + } + } while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) != + MapToDayOfWeekOrDefault(relation_type, 1)); + } + return true; + case DateParseData::DAY: + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, distance, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a day"; + return false; + } + + return true; + case DateParseData::WEEK: + date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1); + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 7 * distance, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a week"; + return false; + } + + return true; + case DateParseData::MONTH: + date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); + date->add(UCalendarDateFields::UCAL_MONTH, 1 * 
distance, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a month"; + return false; + } + return true; + case DateParseData::YEAR: + date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1); + date->add(UCalendarDateFields::UCAL_YEAR, 1 * distance, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a year"; + + return true; + default: + return false; + } + return false; + } +} + +bool RoundToGranularity(DatetimeGranularity granularity, + icu::Calendar* calendar) { + // Force recomputation before doing the rounding. + UErrorCode status = U_ZERO_ERROR; + calendar->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "Can't interpret date."; + return false; + } + + switch (granularity) { + case GRANULARITY_YEAR: + calendar->set(UCalendarDateFields::UCAL_MONTH, 0); + TC_FALLTHROUGH_INTENDED; + case GRANULARITY_MONTH: + calendar->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); + TC_FALLTHROUGH_INTENDED; + case GRANULARITY_DAY: + calendar->set(UCalendarDateFields::UCAL_HOUR, 0); + TC_FALLTHROUGH_INTENDED; + case GRANULARITY_HOUR: + calendar->set(UCalendarDateFields::UCAL_MINUTE, 0); + TC_FALLTHROUGH_INTENDED; + case GRANULARITY_MINUTE: + calendar->set(UCalendarDateFields::UCAL_SECOND, 0); + break; + + case GRANULARITY_WEEK: + calendar->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, + calendar->getFirstDayOfWeek()); + calendar->set(UCalendarDateFields::UCAL_HOUR, 0); + calendar->set(UCalendarDateFields::UCAL_MINUTE, 0); + calendar->set(UCalendarDateFields::UCAL_SECOND, 0); + break; + + case GRANULARITY_UNKNOWN: + case GRANULARITY_SECOND: + break; + } + + return true; +} + +} // namespace + +bool CalendarLib::InterpretParseData(const DateParseData& parse_data, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, + DatetimeGranularity granularity, + int64* interpreted_time_ms_utc) const { + UErrorCode status = U_ZERO_ERROR; + + 
std::unique_ptr<icu::Calendar> date(icu::Calendar::createInstance( + icu::Locale::createFromName(reference_locale.c_str()), status)); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error getting calendar instance"; + return false; + } + + date->adoptTimeZone(icu::TimeZone::createTimeZone( + icu::UnicodeString::fromUTF8(reference_timezone))); + date->setTime(reference_time_ms_utc, status); + + // By default, the parsed time is interpreted to be on the reference day. But + // a parsed date, should have time 0:00:00 unless specified. + date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, 0); + date->set(UCalendarDateFields::UCAL_MINUTE, 0); + date->set(UCalendarDateFields::UCAL_SECOND, 0); + date->set(UCalendarDateFields::UCAL_MILLISECOND, 0); + + static const int64 kMillisInHour = 1000 * 60 * 60; + if (parse_data.field_set_mask & DateParseData::Fields::ZONE_OFFSET_FIELD) { + date->set(UCalendarDateFields::UCAL_ZONE_OFFSET, + parse_data.zone_offset * kMillisInHour); + } + if (parse_data.field_set_mask & DateParseData::Fields::DST_OFFSET_FIELD) { + // convert from hours to milliseconds + date->set(UCalendarDateFields::UCAL_DST_OFFSET, + parse_data.dst_offset * kMillisInHour); + } + + if (parse_data.field_set_mask & DateParseData::Fields::RELATION_FIELD) { + switch (parse_data.relation) { + case DateParseData::Relation::NEXT: + if (parse_data.field_set_mask & + DateParseData::Fields::RELATION_TYPE_FIELD) { + if (!DispatchToAdvancerOrToNextDayOfWeek( + date.get(), parse_data.relation_type, 1)) { + return false; + } + } + break; + case DateParseData::Relation::NEXT_OR_SAME: + if (parse_data.field_set_mask & + DateParseData::Fields::RELATION_TYPE_FIELD) { + if (!DispatchToAdvancerOrToNextOrSameDayOfWeek( + date.get(), parse_data.relation_type)) { + return false; + } + } + break; + case DateParseData::Relation::LAST: + if (parse_data.field_set_mask & + DateParseData::Fields::RELATION_TYPE_FIELD) { + if (!DispatchToRecedeOrToLastDayOfWeek(date.get(), + parse_data.relation_type, 
1)) { + return false; + } + } + break; + case DateParseData::Relation::NOW: + // NOOP + break; + case DateParseData::Relation::TOMORROW: + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error adding a day"; + return false; + } + break; + case DateParseData::Relation::YESTERDAY: + date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -1, status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error subtracting a day"; + return false; + } + break; + case DateParseData::Relation::PAST: + if (parse_data.field_set_mask & + DateParseData::Fields::RELATION_TYPE_FIELD) { + if (parse_data.field_set_mask & + DateParseData::Fields::RELATION_DISTANCE_FIELD) { + if (!DispatchToRecedeOrToLastDayOfWeek( + date.get(), parse_data.relation_type, + parse_data.relation_distance)) { + return false; + } + } + } + break; + case DateParseData::Relation::FUTURE: + if (parse_data.field_set_mask & + DateParseData::Fields::RELATION_TYPE_FIELD) { + if (parse_data.field_set_mask & + DateParseData::Fields::RELATION_DISTANCE_FIELD) { + if (!DispatchToAdvancerOrToNextDayOfWeek( + date.get(), parse_data.relation_type, + parse_data.relation_distance)) { + return false; + } + } + } + break; + } + } + if (parse_data.field_set_mask & DateParseData::Fields::YEAR_FIELD) { + date->set(UCalendarDateFields::UCAL_YEAR, parse_data.year); + } + if (parse_data.field_set_mask & DateParseData::Fields::MONTH_FIELD) { + // NOTE: Java and ICU disagree on month formats + date->set(UCalendarDateFields::UCAL_MONTH, parse_data.month - 1); + } + if (parse_data.field_set_mask & DateParseData::Fields::DAY_FIELD) { + date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, parse_data.day_of_month); + } + if (parse_data.field_set_mask & DateParseData::Fields::HOUR_FIELD) { + if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD && + parse_data.ampm == 1 && parse_data.hour < 12) { + date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, parse_data.hour + 12); + } else { 
+ date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, parse_data.hour); + } + } + if (parse_data.field_set_mask & DateParseData::Fields::MINUTE_FIELD) { + date->set(UCalendarDateFields::UCAL_MINUTE, parse_data.minute); + } + if (parse_data.field_set_mask & DateParseData::Fields::SECOND_FIELD) { + date->set(UCalendarDateFields::UCAL_SECOND, parse_data.second); + } + + if (!RoundToGranularity(granularity, date.get())) { + return false; + } + + *interpreted_time_ms_utc = date->getTime(status); + if (U_FAILURE(status)) { + TC_LOG(ERROR) << "error getting time from instance"; + return false; + } + + return true; +} +} // namespace libtextclassifier2 diff --git a/util/calendar/calendar-icu.h b/util/calendar/calendar-icu.h new file mode 100644 index 0000000..8aae7ab --- /dev/null +++ b/util/calendar/calendar-icu.h @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_ +#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_ + +#include <string> + +#include "types.h" +#include "util/base/integral_types.h" +#include "util/base/logging.h" + +namespace libtextclassifier2 { + +class CalendarLib { + public: + // Interprets parse_data as milliseconds since_epoch. Relative times are + // resolved against the current time (reference_time_ms_utc). Returns true if + // the interpratation was successful, false otherwise. 
+ bool InterpretParseData(const DateParseData& parse_data, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, + DatetimeGranularity granularity, + int64* interpreted_time_ms_utc) const; +}; +} // namespace libtextclassifier2 +#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_ diff --git a/smartselect/model-parser.h b/util/calendar/calendar.h index 801262f..b0cf2e6 100644 --- a/smartselect/model-parser.h +++ b/util/calendar/calendar.h @@ -14,16 +14,9 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_ -#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_ +#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_ +#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_ -namespace libtextclassifier { +#include "util/calendar/calendar-icu.h" -// Parse a merged model image. -bool ParseMergedModel(const void* addr, const int size, - const char** selection_model, int* selection_model_length, - const char** sharing_model, int* sharing_model_length); - -} // namespace libtextclassifier - -#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_ +#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_ diff --git a/util/calendar/calendar_test.cc b/util/calendar/calendar_test.cc new file mode 100644 index 0000000..1f29106 --- /dev/null +++ b/util/calendar/calendar_test.cc @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This test serves the purpose of making sure all the different implementations +// of the unspoken CalendarLib interface support the same methods. + +#include "util/calendar/calendar.h" +#include "util/base/logging.h" + +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +TEST(CalendarTest, Interface) { + CalendarLib calendar; + int64 time; + std::string timezone; + bool result = calendar.InterpretParseData( + DateParseData{0l, 0, 0, 0, 0, 0, 0, 0, 0, 0, + static_cast<DateParseData::Relation>(0), + static_cast<DateParseData::RelationType>(0), 0}, + 0L, "Zurich", "en-CH", GRANULARITY_UNKNOWN, &time); + TC_LOG(INFO) << result; +} + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST(CalendarTest, RoundingToGranularity) { + CalendarLib calendar; + int64 time; + std::string timezone; + DateParseData data; + data.year = 2018; + data.month = 4; + data.day_of_month = 25; + data.hour = 9; + data.minute = 33; + data.second = 59; + data.field_set_mask = DateParseData::YEAR_FIELD | DateParseData::MONTH_FIELD | + DateParseData::DAY_FIELD | DateParseData::HOUR_FIELD | + DateParseData::MINUTE_FIELD | + DateParseData::SECOND_FIELD; + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_YEAR, &time)); + EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_MONTH, &time)); + EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_WEEK, &time)); + 
EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"*-CH", + /*granularity=*/GRANULARITY_WEEK, &time)); + EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-US", + /*granularity=*/GRANULARITY_WEEK, &time)); + EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"*-US", + /*granularity=*/GRANULARITY_WEEK, &time)); + EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_DAY, &time)); + EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_HOUR, &time)); + EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_MINUTE, &time)); + EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */); + + ASSERT_TRUE(calendar.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_SECOND, &time)); + EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_DUMMY + +} // namespace +} // 
namespace libtextclassifier2 diff --git a/common/mock_functions.cc b/util/flatbuffers.cc index c661b70..6c0108e 100644 --- a/common/mock_functions.cc +++ b/util/flatbuffers.cc @@ -14,16 +14,13 @@ * limitations under the License. */ -#include "common/mock_functions.h" +#include "util/flatbuffers.h" -#include "common/registry.h" +namespace libtextclassifier2 { -namespace libtextclassifier { -namespace nlp_core { +template <> +const char* FlatbufferFileIdentifier<Model>() { + return ModelIdentifier(); +} -TC_DEFINE_CLASS_REGISTRY_NAME("function", functions::Function); - -TC_DEFINE_CLASS_REGISTRY_NAME("int-function", functions::IntFunction); - -} // namespace nlp_core -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/util/flatbuffers.h b/util/flatbuffers.h new file mode 100644 index 0000000..93d73b6 --- /dev/null +++ b/util/flatbuffers.h @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Utility functions for working with FlatBuffers. + +#ifndef LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_ +#define LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_ + +#include <memory> +#include <string> + +#include "model_generated.h" +#include "flatbuffers/flatbuffers.h" + +namespace libtextclassifier2 { + +// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its +// integrity. 
+template <typename FlatbufferMessage> +const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) { + const FlatbufferMessage* message = + flatbuffers::GetRoot<FlatbufferMessage>(buffer); + if (message == nullptr) { + return nullptr; + } + flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer), + size); + if (message->Verify(verifier)) { + return message; + } else { + return nullptr; + } +} + +// Same as above but takes string. +template <typename FlatbufferMessage> +const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) { + return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(), + buffer.size()); +} + +// Loads and interprets the buffer as 'FlatbufferMessage', verifies its +// integrity and returns its mutable version. +template <typename FlatbufferMessage> +std::unique_ptr<typename FlatbufferMessage::NativeTableType> +LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) { + const FlatbufferMessage* message = + LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size); + if (message == nullptr) { + return nullptr; + } + return std::unique_ptr<typename FlatbufferMessage::NativeTableType>( + message->UnPack()); +} + +// Same as above but takes string. +template <typename FlatbufferMessage> +std::unique_ptr<typename FlatbufferMessage::NativeTableType> +LoadAndVerifyMutableFlatbuffer(const std::string& buffer) { + return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(), + buffer.size()); +} + +template <typename FlatbufferMessage> +const char* FlatbufferFileIdentifier() { + return nullptr; +} + +template <> +const char* FlatbufferFileIdentifier<Model>(); + +// Packs the mutable flatbuffer message to string. 
+template <typename FlatbufferMessage> +std::string PackFlatbuffer( + const typename FlatbufferMessage::NativeTableType* mutable_message) { + flatbuffers::FlatBufferBuilder builder; + builder.Finish(FlatbufferMessage::Pack(builder, mutable_message), + FlatbufferFileIdentifier<FlatbufferMessage>()); + return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); +} + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_ diff --git a/util/gtl/map_util.h b/util/gtl/map_util.h index b5eaafa..bd020f8 100644 --- a/util/gtl/map_util.h +++ b/util/gtl/map_util.h @@ -17,7 +17,7 @@ #ifndef LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_ #define LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_ -namespace libtextclassifier { +namespace libtextclassifier2 { // Returns a const reference to the value associated with the given key if it // exists, otherwise returns a const reference to the provided default value. @@ -60,6 +60,6 @@ bool InsertIfNotPresent( typename Collection::value_type(key, value)); } -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_ diff --git a/util/gtl/stl_util.h b/util/gtl/stl_util.h index 8e1c452..7b88e05 100644 --- a/util/gtl/stl_util.h +++ b/util/gtl/stl_util.h @@ -17,7 +17,7 @@ #ifndef LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_ #define LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_ -namespace libtextclassifier { +namespace libtextclassifier2 { // Deletes all the elements in an STL container and clears the container. 
This // function is suitable for use with a vector, set, hash_set, or any other STL @@ -50,6 +50,6 @@ void STLDeleteValues(T *container) { container->clear(); } -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_ diff --git a/util/hash/farmhash.h b/util/hash/farmhash.h index 7adf3aa..477b7a8 100644 --- a/util/hash/farmhash.h +++ b/util/hash/farmhash.h @@ -24,7 +24,7 @@ #include <utility> #ifndef NAMESPACE_FOR_HASH_FUNCTIONS -#define NAMESPACE_FOR_HASH_FUNCTIONS tcfarmhash +#define NAMESPACE_FOR_HASH_FUNCTIONS tc2farmhash #endif namespace NAMESPACE_FOR_HASH_FUNCTIONS { diff --git a/util/hash/hash.cc b/util/hash/hash.cc index 1261417..9722ddc 100644 --- a/util/hash/hash.cc +++ b/util/hash/hash.cc @@ -18,7 +18,7 @@ #include "util/base/macros.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace { // Lower-level versions of Get... that read directly from a character buffer @@ -76,4 +76,4 @@ uint32 Hash32(const char *data, size_t n, uint32 seed) { return h; } -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/util/hash/hash.h b/util/hash/hash.h index 0abb72b..b7a3b53 100644 --- a/util/hash/hash.h +++ b/util/hash/hash.h @@ -21,7 +21,7 @@ #include "util/base/integral_types.h" -namespace libtextclassifier { +namespace libtextclassifier2 { uint32 Hash32(const char *data, size_t n, uint32 seed); @@ -33,6 +33,6 @@ static inline uint32 Hash32WithDefaultSeed(const std::string &input) { return Hash32WithDefaultSeed(input.data(), input.size()); } -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_ diff --git a/util/i18n/locale.cc b/util/i18n/locale.cc new file mode 100644 index 0000000..c587d2d --- /dev/null +++ b/util/i18n/locale.cc @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/i18n/locale.h" + +#include "util/strings/split.h" + +namespace libtextclassifier2 { + +namespace { + +bool CheckLanguage(StringPiece language) { + if (language.size() != 2 && language.size() != 3) { + return false; + } + + // Needs to be all lowercase. + for (int i = 0; i < language.size(); ++i) { + if (!std::islower(language[i])) { + return false; + } + } + + return true; +} + +bool CheckScript(StringPiece script) { + if (script.size() != 4) { + return false; + } + + if (!std::isupper(script[0])) { + return false; + } + + // Needs to be all lowercase. 
+ for (int i = 1; i < script.size(); ++i) { + if (!std::islower(script[i])) { + return false; + } + } + + return true; +} + +bool CheckRegion(StringPiece region) { + if (region.size() == 2) { + return std::isupper(region[0]) && std::isupper(region[1]); + } else if (region.size() == 3) { + return std::isdigit(region[0]) && std::isdigit(region[1]) && + std::isdigit(region[2]); + } else { + return false; + } +} + +} // namespace + +Locale Locale::FromBCP47(const std::string& locale_tag) { + std::vector<StringPiece> parts = strings::Split(locale_tag, '-'); + if (parts.empty()) { + return Locale::Invalid(); + } + + auto parts_it = parts.begin(); + StringPiece language = *parts_it; + if (!CheckLanguage(language)) { + return Locale::Invalid(); + } + ++parts_it; + + StringPiece script; + if (parts_it != parts.end()) { + script = *parts_it; + if (!CheckScript(script)) { + script = ""; + } else { + ++parts_it; + } + } + + StringPiece region; + if (parts_it != parts.end()) { + region = *parts_it; + if (!CheckRegion(region)) { + region = ""; + } else { + ++parts_it; + } + } + + // NOTE: We don't parse the rest of the BCP47 tag here even if specified. + + return Locale(language.ToString(), script.ToString(), region.ToString()); +} + +} // namespace libtextclassifier2 diff --git a/util/i18n/locale.h b/util/i18n/locale.h new file mode 100644 index 0000000..16f10dc --- /dev/null +++ b/util/i18n/locale.h @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_ +#define LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_ + +#include <string> + +#include "util/base/integral_types.h" + +namespace libtextclassifier2 { + +class Locale { + public: + // Constructs the object from a valid BCP47 tag. If the tag is invalid, + // an object is created that gives false when IsInvalid() is called. + static Locale FromBCP47(const std::string& locale_tag); + + // Creates a prototypical invalid locale object. + static Locale Invalid() { + Locale locale(/*language=*/"", /*script=*/"", /*region=*/""); + locale.is_valid_ = false; + return locale; + } + + std::string Language() const { return language_; } + + std::string Script() const { return script_; } + + std::string Region() const { return region_; } + + bool IsValid() const { return is_valid_; } + + private: + Locale(const std::string& language, const std::string& script, + const std::string& region) + : language_(language), + script_(script), + region_(region), + is_valid_(true) {} + + std::string language_; + std::string script_; + std::string region_; + bool is_valid_; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_ diff --git a/util/i18n/locale_test.cc b/util/i18n/locale_test.cc new file mode 100644 index 0000000..72ece98 --- /dev/null +++ b/util/i18n/locale_test.cc @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/i18n/locale.h" + +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +TEST(LocaleTest, ParseUnknown) { + Locale locale = Locale::Invalid(); + EXPECT_FALSE(locale.IsValid()); +} + +TEST(LocaleTest, ParseSwissEnglish) { + Locale locale = Locale::FromBCP47("en-CH"); + EXPECT_TRUE(locale.IsValid()); + EXPECT_EQ(locale.Language(), "en"); + EXPECT_EQ(locale.Script(), ""); + EXPECT_EQ(locale.Region(), "CH"); +} + +TEST(LocaleTest, ParseChineseChina) { + Locale locale = Locale::FromBCP47("zh-CN"); + EXPECT_TRUE(locale.IsValid()); + EXPECT_EQ(locale.Language(), "zh"); + EXPECT_EQ(locale.Script(), ""); + EXPECT_EQ(locale.Region(), "CN"); +} + +TEST(LocaleTest, ParseChineseTaiwan) { + Locale locale = Locale::FromBCP47("zh-Hant-TW"); + EXPECT_TRUE(locale.IsValid()); + EXPECT_EQ(locale.Language(), "zh"); + EXPECT_EQ(locale.Script(), "Hant"); + EXPECT_EQ(locale.Region(), "TW"); +} + +TEST(LocaleTest, ParseEnglish) { + Locale locale = Locale::FromBCP47("en"); + EXPECT_TRUE(locale.IsValid()); + EXPECT_EQ(locale.Language(), "en"); + EXPECT_EQ(locale.Script(), ""); + EXPECT_EQ(locale.Region(), ""); +} + +TEST(LocaleTest, ParseCineseTraditional) { + Locale locale = Locale::FromBCP47("zh-Hant"); + EXPECT_TRUE(locale.IsValid()); + EXPECT_EQ(locale.Language(), "zh"); + EXPECT_EQ(locale.Script(), "Hant"); + EXPECT_EQ(locale.Region(), ""); +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/util/java/scoped_global_ref.h b/util/java/scoped_global_ref.h new file mode 100644 index 0000000..3f8754d --- /dev/null +++ b/util/java/scoped_global_ref.h @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_GLOBAL_REF_H_ +#define LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_GLOBAL_REF_H_ + +#include <jni.h> +#include <memory> +#include <type_traits> + +#include "util/base/logging.h" + +namespace libtextclassifier2 { + +// A deleter to be used with std::unique_ptr to delete JNI global references. +class GlobalRefDeleter { + public: + GlobalRefDeleter() : jvm_(nullptr) {} + + // Style guide violating implicit constructor so that the GlobalRefDeleter + // is implicitly constructed from the second argument to ScopedGlobalRef. + GlobalRefDeleter(JavaVM* jvm) : jvm_(jvm) {} // NOLINT(runtime/explicit) + + GlobalRefDeleter(const GlobalRefDeleter& orig) = default; + + // Copy assignment to allow move semantics in ScopedGlobalRef. + GlobalRefDeleter& operator=(const GlobalRefDeleter& rhs) { + TC_CHECK_EQ(jvm_, rhs.jvm_); + return *this; + } + + // The delete operator. + void operator()(jobject object) const { + JNIEnv* env; + if (object != nullptr && jvm_ != nullptr && + JNI_OK == + jvm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_4)) { + env->DeleteGlobalRef(object); + } + } + + private: + // The jvm_ stashed to use for deletion. + JavaVM* const jvm_; +}; + +// A smart pointer that deletes a JNI global reference when it goes out +// of scope. Usage is: +// ScopedGlobalRef<jobject> scoped_global(env->JniFunction(), jvm); +template <typename T> +using ScopedGlobalRef = + std::unique_ptr<typename std::remove_pointer<T>::type, GlobalRefDeleter>; + +// A helper to create global references. 
+template <typename T> +ScopedGlobalRef<T> MakeGlobalRef(T object, JNIEnv* env, JavaVM* jvm) { + const jobject globalObject = env->NewGlobalRef(object); + return ScopedGlobalRef<T>(reinterpret_cast<T>(globalObject), jvm); +} + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_GLOBAL_REF_H_ diff --git a/util/java/scoped_local_ref.h b/util/java/scoped_local_ref.h index d995468..8476767 100644 --- a/util/java/scoped_local_ref.h +++ b/util/java/scoped_local_ref.h @@ -23,11 +23,13 @@ #include "util/base/logging.h" -namespace libtextclassifier { +namespace libtextclassifier2 { // A deleter to be used with std::unique_ptr to delete JNI local references. class LocalRefDeleter { public: + LocalRefDeleter() : env_(nullptr) {} + // Style guide violating implicit constructor so that the LocalRefDeleter // is implicitly constructed from the second argument to ScopedLocalRef. LocalRefDeleter(JNIEnv* env) : env_(env) {} // NOLINT(runtime/explicit) @@ -43,7 +45,11 @@ class LocalRefDeleter { } // The delete operator. - void operator()(jobject o) const { env_->DeleteLocalRef(o); } + void operator()(jobject object) const { + if (env_) { + env_->DeleteLocalRef(object); + } + } private: // The env_ stashed to use for deletion. Thread-local, don't share! @@ -60,6 +66,6 @@ template <typename T> using ScopedLocalRef = std::unique_ptr<typename std::remove_pointer<T>::type, LocalRefDeleter>; -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_ diff --git a/util/java/string_utils.cc b/util/java/string_utils.cc new file mode 100644 index 0000000..ffd5b11 --- /dev/null +++ b/util/java/string_utils.cc @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/java/string_utils.h" + +#include "util/base/logging.h" + +namespace libtextclassifier2 { + +bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, + std::string* result) { + if (jstr == nullptr) { + *result = std::string(); + return false; + } + + jclass string_class = env->FindClass("java/lang/String"); + if (!string_class) { + TC_LOG(ERROR) << "Can't find String class"; + return false; + } + + jmethodID get_bytes_id = + env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); + + jstring encoding = env->NewStringUTF("UTF-8"); + jbyteArray array = reinterpret_cast<jbyteArray>( + env->CallObjectMethod(jstr, get_bytes_id, encoding)); + + jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE); + int length = env->GetArrayLength(array); + + *result = std::string(reinterpret_cast<char*>(array_bytes), length); + + // Release the array. + env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT); + env->DeleteLocalRef(array); + env->DeleteLocalRef(string_class); + env->DeleteLocalRef(encoding); + + return true; +} + +} // namespace libtextclassifier2 diff --git a/lang_id/light-sentence-features.cc b/util/java/string_utils.h index aec6b81..6a85856 100644 --- a/lang_id/light-sentence-features.cc +++ b/util/java/string_utils.h @@ -14,16 +14,16 @@ * limitations under the License. 
*/ -#include "lang_id/light-sentence-features.h" +#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_ +#define LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_ -#include "common/registry.h" +#include <jni.h> +#include <string> -namespace libtextclassifier { -namespace nlp_core { +namespace libtextclassifier2 { -// Registry for the features on whole light sentences. -TC_DEFINE_CLASS_REGISTRY_NAME("light sentence feature function", - lang_id::LightSentenceFeature); +bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result); -} // namespace nlp_core -} // namespace libtextclassifier +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_ diff --git a/common/fastexp.cc b/util/math/fastexp.cc index 0376ad2..4bf8592 100644 --- a/common/fastexp.cc +++ b/util/math/fastexp.cc @@ -14,10 +14,9 @@ * limitations under the License. */ -#include "common/fastexp.h" +#include "util/math/fastexp.h" -namespace libtextclassifier { -namespace nlp_core { +namespace libtextclassifier2 { const int FastMathClass::kBits; const int FastMathClass::kMask1; @@ -46,5 +45,4 @@ const FastMathClass::Table FastMathClass::cache_ = { 7940441, 8029106, 8118253, 8207884, 8298001} }; -} // namespace nlp_core -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/common/fastexp.h b/util/math/fastexp.h index 1781b36..af7a08c 100644 --- a/common/fastexp.h +++ b/util/math/fastexp.h @@ -16,8 +16,8 @@ // Fast approximation for exp. 
-#ifndef LIBTEXTCLASSIFIER_COMMON_FASTEXP_H_ -#define LIBTEXTCLASSIFIER_COMMON_FASTEXP_H_ +#ifndef LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_ +#define LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_ #include <cassert> #include <cmath> @@ -27,8 +27,7 @@ #include "util/base/integral_types.h" #include "util/base/logging.h" -namespace libtextclassifier { -namespace nlp_core { +namespace libtextclassifier2 { class FastMathClass { private: @@ -64,7 +63,6 @@ extern FastMathClass FastMathInstance; inline float VeryFastExp2(float f) { return FastMathInstance.VeryFastExp2(f); } inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); } -} // namespace nlp_core -} // namespace libtextclassifier +} // namespace libtextclassifier2 -#endif // LIBTEXTCLASSIFIER_COMMON_FASTEXP_H_ +#endif // LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_ diff --git a/common/softmax.cc b/util/math/softmax.cc index 3610de8..986787f 100644 --- a/common/softmax.cc +++ b/util/math/softmax.cc @@ -14,15 +14,14 @@ * limitations under the License. */ -#include "common/softmax.h" +#include "util/math/softmax.h" #include <limits> -#include "common/fastexp.h" #include "util/base/logging.h" +#include "util/math/fastexp.h" -namespace libtextclassifier { -namespace nlp_core { +namespace libtextclassifier2 { float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) { if ((label < 0) || (label >= scores.size())) { @@ -71,18 +70,24 @@ float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) { } std::vector<float> ComputeSoftmax(const std::vector<float> &scores) { + return ComputeSoftmax(scores.data(), scores.size()); +} + +std::vector<float> ComputeSoftmax(const float *scores, int scores_size) { std::vector<float> softmax; std::vector<float> exp_scores; - exp_scores.reserve(scores.size()); - softmax.reserve(scores.size()); + exp_scores.reserve(scores_size); + softmax.reserve(scores_size); // Find max value in "scores" vector and rescale to avoid overflows. 
float max = std::numeric_limits<float>::min(); - for (const auto &score : scores) { + for (int i = 0; i < scores_size; ++i) { + const float score = scores[i]; if (score > max) max = score; } float denominator = 0; - for (auto &score : scores) { + for (int i = 0; i < scores_size; ++i) { + const float score = scores[i]; // See comments above in ComputeSoftmaxProbability for the reasoning behind // this approximation. const float exp_score = score - max < -16.0f ? 0 : VeryFastExp(score - max); @@ -90,11 +95,10 @@ std::vector<float> ComputeSoftmax(const std::vector<float> &scores) { denominator += exp_score; } - for (int i = 0; i < scores.size(); ++i) { + for (int i = 0; i < scores_size; ++i) { softmax.push_back(exp_scores[i] / denominator); } return softmax; } -} // namespace nlp_core -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/common/softmax.h b/util/math/softmax.h index e1cc2d9..f70a9ab 100644 --- a/common/softmax.h +++ b/util/math/softmax.h @@ -14,13 +14,12 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_COMMON_SOFTMAX_H_ -#define LIBTEXTCLASSIFIER_COMMON_SOFTMAX_H_ +#ifndef LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_ +#define LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_ #include <vector> -namespace libtextclassifier { -namespace nlp_core { +namespace libtextclassifier2 { // Computes probability of a softmax label. Parameter "scores" is the vector of // softmax logits. Returns 0.0f if "label" is outside the range [0, @@ -31,7 +30,9 @@ float ComputeSoftmaxProbability(const std::vector<float> &scores, int label); // "scores" is the vector of softmax logits. std::vector<float> ComputeSoftmax(const std::vector<float> &scores); -} // namespace nlp_core -} // namespace libtextclassifier +// Same as above but operates on an array of floats. 
+std::vector<float> ComputeSoftmax(const float *scores, int scores_size); -#endif // LIBTEXTCLASSIFIER_COMMON_SOFTMAX_H_ +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_ diff --git a/common/mmap.cc b/util/memory/mmap.cc index 6e15a84..6b0bdf2 100644 --- a/common/mmap.cc +++ b/util/memory/mmap.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "common/mmap.h" +#include "util/memory/mmap.h" #include <errno.h> #include <fcntl.h> @@ -27,8 +27,7 @@ #include "util/base/logging.h" #include "util/base/macros.h" -namespace libtextclassifier { -namespace nlp_core { +namespace libtextclassifier2 { namespace { inline std::string GetLastSystemError() { return std::string(strerror(errno)); } @@ -133,5 +132,4 @@ bool Unmap(MmapHandle mmap_handle) { return true; } -} // namespace nlp_core -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/common/mmap.h b/util/memory/mmap.h index 69f7b4c..7d28b64 100644 --- a/common/mmap.h +++ b/util/memory/mmap.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_COMMON_MMAP_H_ -#define LIBTEXTCLASSIFIER_COMMON_MMAP_H_ +#ifndef LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_ +#define LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_ #include <stddef.h> @@ -24,8 +24,7 @@ #include "util/base/integral_types.h" #include "util/strings/stringpiece.h" -namespace libtextclassifier { -namespace nlp_core { +namespace libtextclassifier2 { // Handle for a memory area where a file has been mmapped. 
// @@ -137,7 +136,6 @@ class ScopedMmap { MmapHandle handle_; }; -} // namespace nlp_core -} // namespace libtextclassifier +} // namespace libtextclassifier2 -#endif // LIBTEXTCLASSIFIER_COMMON_MMAP_H_ +#endif // LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_ diff --git a/util/strings/numbers.cc b/util/strings/numbers.cc index 4bd8b82..a89c0ef 100644 --- a/util/strings/numbers.cc +++ b/util/strings/numbers.cc @@ -22,7 +22,7 @@ #include <stdlib.h> -namespace libtextclassifier { +namespace libtextclassifier2 { bool ParseInt32(const char *c_str, int32 *value) { char *temp; @@ -72,4 +72,4 @@ std::string IntToString(int64 input) { } #endif // COMPILER_MSVC -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/util/strings/numbers.h b/util/strings/numbers.h index eda53bf..a2c8c6e 100644 --- a/util/strings/numbers.h +++ b/util/strings/numbers.h @@ -21,7 +21,7 @@ #include "util/base/integral_types.h" -namespace libtextclassifier { +namespace libtextclassifier2 { // Parses an int32 from a C-style string. // @@ -47,7 +47,6 @@ bool ParseDouble(const char *c_str, double *value); // int types. 
std::string IntToString(int64 input); - -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_ diff --git a/util/strings/numbers_test.cc b/util/strings/numbers_test.cc index f3a3f27..1fdd78a 100644 --- a/util/strings/numbers_test.cc +++ b/util/strings/numbers_test.cc @@ -19,7 +19,7 @@ #include "util/base/integral_types.h" #include "gtest/gtest.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace { void TestParseInt32(const char *c_str, bool expected_parsing_success, @@ -100,4 +100,4 @@ TEST(ParseDoubleTest, ErrorCases) { TestParseDouble("23.5a", false); } } // namespace -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/util/strings/split.cc b/util/strings/split.cc index 8d250bb..2c610ba 100644 --- a/util/strings/split.cc +++ b/util/strings/split.cc @@ -16,17 +16,17 @@ #include "util/strings/split.h" -namespace libtextclassifier { +namespace libtextclassifier2 { namespace strings { -std::vector<std::string> Split(const std::string &text, char delim) { - std::vector<std::string> result; +std::vector<StringPiece> Split(const StringPiece &text, char delim) { + std::vector<StringPiece> result; int token_start = 0; if (!text.empty()) { for (size_t i = 0; i < text.size() + 1; i++) { if ((i == text.size()) || (text[i] == delim)) { result.push_back( - std::string(text.data() + token_start, i - token_start)); + StringPiece(text.data() + token_start, i - token_start)); token_start = i + 1; } } @@ -35,4 +35,4 @@ std::vector<std::string> Split(const std::string &text, char delim) { } } // namespace strings -} // namespace libtextclassifier +} // namespace libtextclassifier2 diff --git a/util/strings/split.h b/util/strings/split.h index b661ede0..96f73fe 100644 --- a/util/strings/split.h +++ b/util/strings/split.h @@ -20,12 +20,14 @@ #include <string> #include <vector> -namespace libtextclassifier { +#include "util/strings/stringpiece.h" + +namespace 
libtextclassifier2 { namespace strings { -std::vector<std::string> Split(const std::string &text, char delim); +std::vector<StringPiece> Split(const StringPiece &text, char delim); } // namespace strings -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_ diff --git a/util/strings/stringpiece.h b/util/strings/stringpiece.h index 8c42d83..cd07848 100644 --- a/util/strings/stringpiece.h +++ b/util/strings/stringpiece.h @@ -21,7 +21,7 @@ #include <string> -namespace libtextclassifier { +namespace libtextclassifier2 { // Read-only "view" of a piece of data. Does not own the underlying data. class StringPiece { @@ -51,6 +51,8 @@ class StringPiece { size_t size() const { return size_; } size_t length() const { return size_; } + bool empty() const { return size_ == 0; } + // Returns a std::string containing a copy of the underlying data. std::string ToString() const { return std::string(data(), size()); @@ -61,6 +63,6 @@ class StringPiece { size_t size_; }; -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_ diff --git a/util/strings/utf8.cc b/util/strings/utf8.cc new file mode 100644 index 0000000..39dcb4e --- /dev/null +++ b/util/strings/utf8.cc @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "util/strings/utf8.h" + +namespace libtextclassifier2 { +bool IsValidUTF8(const char *src, int size) { + for (int i = 0; i < size;) { + // Unexpected trail byte. + if (IsTrailByte(src[i])) { + return false; + } + + const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[i]); + if (num_codepoint_bytes <= 0 || i + num_codepoint_bytes > size) { + return false; + } + + // Check that remaining bytes in the codepoint are trailing bytes. + i++; + for (int k = 1; k < num_codepoint_bytes; k++, i++) { + if (!IsTrailByte(src[i])) { + return false; + } + } + } + return true; +} +} // namespace libtextclassifier2 diff --git a/util/strings/utf8.h b/util/strings/utf8.h index 93c7fea..1e75da2 100644 --- a/util/strings/utf8.h +++ b/util/strings/utf8.h @@ -17,7 +17,7 @@ #ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_ #define LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_ -namespace libtextclassifier { +namespace libtextclassifier2 { // Returns the length (number of bytes) of the Unicode code point starting at // src, based on inspecting just that one byte. Preconditions: src != NULL, @@ -44,6 +44,9 @@ static inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } -} // namespace libtextclassifier +// Returns true iff src points to a well-formed UTF-8 string. +bool IsValidUTF8(const char *src, int size); + +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_ diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc index dbab1c8..2ef79e9 100644 --- a/util/utf8/unicodetext.cc +++ b/util/utf8/unicodetext.cc @@ -22,11 +22,21 @@ #include "util/strings/utf8.h" -namespace libtextclassifier { +namespace libtextclassifier2 { // *************** Data representation ********** // Note: the copy constructor is undefined. 
+UnicodeText::Repr& UnicodeText::Repr::operator=(Repr&& src) { + if (ours_ && data_) delete[] data_; + data_ = src.data_; + size_ = src.size_; + capacity_ = src.capacity_; + ours_ = src.ours_; + src.ours_ = false; + return *this; +} + void UnicodeText::Repr::PointTo(const char* data, int size) { if (ours_ && data_) delete[] data_; // If we owned the old buffer, free it. data_ = const_cast<char*>(data); @@ -89,6 +99,11 @@ UnicodeText::UnicodeText() {} UnicodeText::UnicodeText(const UnicodeText& src) { Copy(src); } +UnicodeText& UnicodeText::operator=(UnicodeText&& src) { + this->repr_ = std::move(src.repr_); + return *this; +} + UnicodeText& UnicodeText::Copy(const UnicodeText& src) { repr_.Copy(src.repr_.data_, src.repr_.size_); return *this; @@ -109,9 +124,87 @@ UnicodeText& UnicodeText::AppendUTF8(const char* utf8, int len) { return *this; } +const char* UnicodeText::data() const { return repr_.data_; } + +int UnicodeText::size_bytes() const { return repr_.size_; } + +namespace { + +enum { + RuneError = 0xFFFD, // Decoding error in UTF. + RuneMax = 0x10FFFF, // Maximum rune value. +}; + +int runetochar(const char32 rune, char* dest) { + // Convert to unsigned for range check. 
+ uint32 c; + + // 1 char 00-7F + c = rune; + if (c <= 0x7F) { + dest[0] = static_cast<char>(c); + return 1; + } + + // 2 char 0080-07FF + if (c <= 0x07FF) { + dest[0] = 0xC0 | static_cast<char>(c >> 1 * 6); + dest[1] = 0x80 | (c & 0x3F); + return 2; + } + + // Range check + if (c > RuneMax) { + c = RuneError; + } + + // 3 char 0800-FFFF + if (c <= 0xFFFF) { + dest[0] = 0xE0 | static_cast<char>(c >> 2 * 6); + dest[1] = 0x80 | ((c >> 1 * 6) & 0x3F); + dest[2] = 0x80 | (c & 0x3F); + return 3; + } + + // 4 char 10000-1FFFFF + dest[0] = 0xF0 | static_cast<char>(c >> 3 * 6); + dest[1] = 0x80 | ((c >> 2 * 6) & 0x3F); + dest[2] = 0x80 | ((c >> 1 * 6) & 0x3F); + dest[3] = 0x80 | (c & 0x3F); + return 4; +} + +} // namespace + +UnicodeText& UnicodeText::AppendCodepoint(char32 ch) { + char str[4]; + int char_len = runetochar(ch, str); + repr_.append(str, char_len); + return *this; +} + void UnicodeText::clear() { repr_.clear(); } -int UnicodeText::size() const { return std::distance(begin(), end()); } +int UnicodeText::size_codepoints() const { + return std::distance(begin(), end()); +} + +bool UnicodeText::empty() const { return size_bytes() == 0; } + +bool UnicodeText::is_valid() const { + return IsValidUTF8(repr_.data_, repr_.size_); +} + +bool UnicodeText::operator==(const UnicodeText& other) const { + if (repr_.size_ != other.repr_.size_) { + return false; + } + return memcmp(repr_.data_, other.repr_.data_, repr_.size_) == 0; +} + +std::string UnicodeText::ToUTF8String() const { + return UTF8Substring(begin(), end()); +} std::string UnicodeText::UTF8Substring(const const_iterator& first, const const_iterator& last) { @@ -191,8 +284,16 @@ UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy) { return t; } +UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy) { + return UTF8ToUnicodeText(utf8_buf, strlen(utf8_buf), do_copy); +} + UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy) { return UTF8ToUnicodeText(str.data(), 
str.size(), do_copy); } -} // namespace libtextclassifier +UnicodeText UTF8ToUnicodeText(const std::string& str) { + return UTF8ToUnicodeText(str, /*do_copy=*/true); +} + +} // namespace libtextclassifier2 diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h index 6a21058..ec08f53 100644 --- a/util/utf8/unicodetext.h +++ b/util/utf8/unicodetext.h @@ -23,7 +23,7 @@ #include "util/base/integral_types.h" -namespace libtextclassifier { +namespace libtextclassifier2 { // ***************************** UnicodeText ************************** // @@ -68,6 +68,7 @@ class UnicodeText { UnicodeText(); // Create an empty text. UnicodeText(const UnicodeText& src); + UnicodeText& operator=(UnicodeText&& src); ~UnicodeText(); class const_iterator { @@ -77,7 +78,7 @@ class UnicodeText { typedef std::input_iterator_tag iterator_category; typedef char32 value_type; typedef int difference_type; - typedef void pointer; // (Not needed.) + typedef void pointer; // (Not needed.) typedef const char32 reference; // (Needed for const_reverse_iterator) // Iterators are default-constructible. @@ -88,7 +89,7 @@ class UnicodeText { char32 operator*() const; // Dereference - const_iterator& operator++(); // Advance (++iter) + const_iterator& operator++(); // Advance (++iter) const_iterator operator++(int) { // (iter++) const_iterator result(*this); ++*this; @@ -132,14 +133,31 @@ class UnicodeText { private: friend class UnicodeText; - explicit const_iterator(const char *it) : it_(it) {} + explicit const_iterator(const char* it) : it_(it) {} - const char *it_; + const char* it_; }; const_iterator begin() const; const_iterator end() const; - int size() const; // the number of Unicode characters (codepoints) + + // Gets pointer to the underlying utf8 data. + const char* data() const; + + // Gets length (in bytes) of the underlying utf8 data. + int size_bytes() const; + + // Computes length (in number of Unicode codepoints) of the underlying utf8 + // data. + // NOTE: Complexity O(n). 
+ int size_codepoints() const; + + bool empty() const; + + // Checks whether the underlying data is valid utf8 data. + bool is_valid() const; + + bool operator==(const UnicodeText& other) const; // x.PointToUTF8(buf,len) changes x so that it points to buf // ("becomes an alias"). It does not take ownership or copy buf. @@ -150,8 +168,10 @@ class UnicodeText { // Calling this may invalidate pointers to underlying data. UnicodeText& AppendUTF8(const char* utf8, int len); + UnicodeText& AppendCodepoint(char32 ch); void clear(); + std::string ToUTF8String() const; static std::string UTF8Substring(const const_iterator& first, const const_iterator& last); @@ -166,6 +186,7 @@ class UnicodeText { bool ours_; // Do we own data_? Repr() : data_(nullptr), size_(0), capacity_(0), ours_(true) {} + Repr& operator=(Repr&& src); ~Repr() { if (ours_) delete[] data_; } @@ -176,7 +197,6 @@ class UnicodeText { void append(const char* bytes, int byte_length); void Copy(const char* data, int size); - void TakeOwnershipOf(char* data, int size, int capacity); void PointTo(const char* data, int size); private: @@ -190,9 +210,15 @@ class UnicodeText { typedef std::pair<UnicodeText::const_iterator, UnicodeText::const_iterator> UnicodeTextRange; +// NOTE: The following are needed to avoid implicit conversion from char* to +// std::string, or from ::string to std::string, because if this happens it +// often results in invalid memory access to a temporary object created during +// such conversion (if do_copy == false). 
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy); +UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy); UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy); +UnicodeText UTF8ToUnicodeText(const std::string& str); -} // namespace libtextclassifier +} // namespace libtextclassifier2 #endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_ diff --git a/util/utf8/unicodetext_test.cc b/util/utf8/unicodetext_test.cc new file mode 100644 index 0000000..9ec7621 --- /dev/null +++ b/util/utf8/unicodetext_test.cc @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/utf8/unicodetext.h" + +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +class UnicodeTextTest : public testing::Test { + protected: + UnicodeTextTest() : empty_text_() { + text_.AppendCodepoint(0x1C0); + text_.AppendCodepoint(0x4E8C); + text_.AppendCodepoint(0xD7DB); + text_.AppendCodepoint(0x34); + text_.AppendCodepoint(0x1D11E); + } + + UnicodeText empty_text_; + UnicodeText text_; +}; + +// Tests for our modifications of UnicodeText. 
+TEST(UnicodeTextTest, Custom) { + UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false); + EXPECT_EQ(text.ToUTF8String(), "1234😋hello"); + EXPECT_EQ(text.size_codepoints(), 10); + EXPECT_EQ(text.size_bytes(), 13); + + auto it_begin = text.begin(); + std::advance(it_begin, 4); + auto it_end = text.begin(); + std::advance(it_end, 6); + EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h"); +} + +TEST(UnicodeTextTest, Ownership) { + const std::string src = "\u304A\u00B0\u106B"; + + UnicodeText alias; + alias.PointToUTF8(src.data(), src.size()); + EXPECT_EQ(alias.data(), src.data()); + UnicodeText::const_iterator it = alias.begin(); + EXPECT_EQ(*it++, 0x304A); + EXPECT_EQ(*it++, 0x00B0); + EXPECT_EQ(*it++, 0x106B); + EXPECT_EQ(it, alias.end()); + + UnicodeText t = alias; // Copy initialization copies the data. + EXPECT_NE(t.data(), alias.data()); +} + +TEST(UnicodeTextTest, Validation) { + EXPECT_TRUE(UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false).is_valid()); + EXPECT_TRUE( + UTF8ToUnicodeText("\u304A\u00B0\u106B", /*do_copy=*/false).is_valid()); + EXPECT_TRUE( + UTF8ToUnicodeText("this is a test😋😋😋", /*do_copy=*/false).is_valid()); + EXPECT_TRUE( + UTF8ToUnicodeText("\xf0\x9f\x98\x8b", /*do_copy=*/false).is_valid()); + // Too short (string is too short). + EXPECT_FALSE(UTF8ToUnicodeText("\xf0\x9f", /*do_copy=*/false).is_valid()); + // Too long (too many trailing bytes). + EXPECT_FALSE( + UTF8ToUnicodeText("\xf0\x9f\x98\x8b\x8b", /*do_copy=*/false).is_valid()); + // Too short (too few trailing bytes). + EXPECT_FALSE( + UTF8ToUnicodeText("\xf0\x9f\x98\x61\x61", /*do_copy=*/false).is_valid()); + // Invalid with context. + EXPECT_FALSE( + UTF8ToUnicodeText("hello \xf0\x9f\x98\x61\x61 world1", /*do_copy=*/false) + .is_valid()); +} + +class IteratorTest : public UnicodeTextTest {}; + +TEST_F(IteratorTest, Iterates) { + UnicodeText::const_iterator iter = text_.begin(); + EXPECT_EQ(0x1C0, *iter); + EXPECT_EQ(&iter, &++iter); // operator++ returns *this. 
+ EXPECT_EQ(0x4E8C, *iter++); + EXPECT_EQ(0xD7DB, *iter); + // Make sure you can dereference more than once. + EXPECT_EQ(0xD7DB, *iter); + EXPECT_EQ(0x34, *++iter); + EXPECT_EQ(0x1D11E, *++iter); + ASSERT_TRUE(iter != text_.end()); + iter++; + EXPECT_TRUE(iter == text_.end()); +} + +TEST_F(IteratorTest, MultiPass) { + // Also tests Default Constructible and Assignable. + UnicodeText::const_iterator i1, i2; + i1 = text_.begin(); + i2 = i1; + EXPECT_EQ(0x4E8C, *++i1); + EXPECT_TRUE(i1 != i2); + EXPECT_EQ(0x1C0, *i2); + ++i2; + EXPECT_TRUE(i1 == i2); + EXPECT_EQ(0x4E8C, *i2); +} + +TEST_F(IteratorTest, ReverseIterates) { + UnicodeText::const_iterator iter = text_.end(); + EXPECT_TRUE(iter == text_.end()); + iter--; + ASSERT_TRUE(iter != text_.end()); + EXPECT_EQ(0x1D11E, *iter--); + EXPECT_EQ(0x34, *iter); + EXPECT_EQ(0xD7DB, *--iter); + // Make sure you can dereference more than once. + EXPECT_EQ(0xD7DB, *iter); + --iter; + EXPECT_EQ(0x4E8C, *iter--); + EXPECT_EQ(0x1C0, *iter); + EXPECT_TRUE(iter == text_.begin()); +} + +TEST_F(IteratorTest, Comparable) { + UnicodeText::const_iterator i1, i2; + i1 = text_.begin(); + i2 = i1; + ++i2; + + EXPECT_TRUE(i1 < i2); + EXPECT_TRUE(text_.begin() <= i1); + EXPECT_FALSE(i1 >= i2); + EXPECT_FALSE(i1 > text_.end()); +} + +TEST_F(IteratorTest, Advance) { + UnicodeText::const_iterator iter = text_.begin(); + EXPECT_EQ(0x1C0, *iter); + std::advance(iter, 4); + EXPECT_EQ(0x1D11E, *iter); + ++iter; + EXPECT_TRUE(iter == text_.end()); +} + +TEST_F(IteratorTest, Distance) { + UnicodeText::const_iterator iter = text_.begin(); + EXPECT_EQ(0, std::distance(text_.begin(), iter)); + EXPECT_EQ(5, std::distance(iter, text_.end())); + ++iter; + ++iter; + EXPECT_EQ(2, std::distance(text_.begin(), iter)); + EXPECT_EQ(3, std::distance(iter, text_.end())); + ++iter; + ++iter; + EXPECT_EQ(4, std::distance(text_.begin(), iter)); + ++iter; + EXPECT_EQ(0, std::distance(iter, text_.end())); +} + +class OperatorTest : public UnicodeTextTest {}; + 
+TEST_F(OperatorTest, Clear) { + UnicodeText empty_text(UTF8ToUnicodeText("", /*do_copy=*/false)); + EXPECT_FALSE(text_ == empty_text); + text_.clear(); + EXPECT_TRUE(text_ == empty_text); +} + +TEST_F(OperatorTest, Empty) { + EXPECT_TRUE(empty_text_.empty()); + EXPECT_FALSE(text_.empty()); + text_.clear(); + EXPECT_TRUE(text_.empty()); +} + +} // namespace +} // namespace libtextclassifier2 diff --git a/util/utf8/unilib-icu.cc b/util/utf8/unilib-icu.cc new file mode 100644 index 0000000..9e9ce19 --- /dev/null +++ b/util/utf8/unilib-icu.cc @@ -0,0 +1,293 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "util/utf8/unilib-icu.h" + +#include <utility> + +namespace libtextclassifier2 { + +bool UniLib::ParseInt32(const UnicodeText& text, int* result) const { + UErrorCode status = U_ZERO_ERROR; + UNumberFormat* format_alias = + unum_open(UNUM_DECIMAL, nullptr, 0, "en_US_POSIX", nullptr, &status); + if (U_FAILURE(status)) { + return false; + } + icu::UnicodeString utf8_string = icu::UnicodeString::fromUTF8( + icu::StringPiece(text.data(), text.size_bytes())); + int parse_index = 0; + const int32 integer = unum_parse(format_alias, utf8_string.getBuffer(), + utf8_string.length(), &parse_index, &status); + *result = integer; + unum_close(format_alias); + if (U_FAILURE(status) || parse_index != utf8_string.length()) { + return false; + } + return true; +} + +bool UniLib::IsOpeningBracket(char32 codepoint) const { + return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) == + U_BPT_OPEN; +} + +bool UniLib::IsClosingBracket(char32 codepoint) const { + return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) == + U_BPT_CLOSE; +} + +bool UniLib::IsWhitespace(char32 codepoint) const { + return u_isWhitespace(codepoint); +} + +bool UniLib::IsDigit(char32 codepoint) const { return u_isdigit(codepoint); } + +bool UniLib::IsUpper(char32 codepoint) const { return u_isupper(codepoint); } + +char32 UniLib::ToLower(char32 codepoint) const { return u_tolower(codepoint); } + +char32 UniLib::GetPairedBracket(char32 codepoint) const { + return u_getBidiPairedBracket(codepoint); +} + +UniLib::RegexMatcher::RegexMatcher(icu::RegexPattern* pattern, + icu::UnicodeString text) + : text_(std::move(text)), + last_find_offset_(0), + last_find_offset_codepoints_(0), + last_find_offset_dirty_(true) { + UErrorCode status = U_ZERO_ERROR; + matcher_.reset(pattern->matcher(text_, status)); + if (U_FAILURE(status)) { + matcher_.reset(nullptr); + } +} + +std::unique_ptr<UniLib::RegexMatcher> UniLib::RegexPattern::Matcher( + const UnicodeText& input) const { + 
return std::unique_ptr<UniLib::RegexMatcher>(new UniLib::RegexMatcher( + pattern_.get(), icu::UnicodeString::fromUTF8( + icu::StringPiece(input.data(), input.size_bytes())))); +} + +constexpr int UniLib::RegexMatcher::kError; +constexpr int UniLib::RegexMatcher::kNoError; + +bool UniLib::RegexMatcher::Matches(int* status) const { + if (!matcher_) { + *status = kError; + return false; + } + + UErrorCode icu_status = U_ZERO_ERROR; + const bool result = matcher_->matches(/*startIndex=*/0, icu_status); + if (U_FAILURE(icu_status)) { + *status = kError; + return false; + } + *status = kNoError; + return result; +} + +bool UniLib::RegexMatcher::ApproximatelyMatches(int* status) { + if (!matcher_) { + *status = kError; + return false; + } + + matcher_->reset(); + *status = kNoError; + if (!Find(status) || *status != kNoError) { + return false; + } + const int found_start = Start(status); + if (*status != kNoError) { + return false; + } + const int found_end = End(status); + if (*status != kNoError) { + return false; + } + if (found_start != 0 || found_end != text_.countChar32()) { + return false; + } + return true; +} + +bool UniLib::RegexMatcher::UpdateLastFindOffset() const { + if (!last_find_offset_dirty_) { + return true; + } + + // Update the position of the match. 
+ UErrorCode icu_status = U_ZERO_ERROR; + const int find_offset = matcher_->start(0, icu_status); + if (U_FAILURE(icu_status)) { + return false; + } + last_find_offset_codepoints_ += + text_.countChar32(last_find_offset_, find_offset - last_find_offset_); + last_find_offset_ = find_offset; + last_find_offset_dirty_ = false; + + return true; +} + +bool UniLib::RegexMatcher::Find(int* status) { + if (!matcher_) { + *status = kError; + return false; + } + UErrorCode icu_status = U_ZERO_ERROR; + const bool result = matcher_->find(icu_status); + if (U_FAILURE(icu_status)) { + *status = kError; + return false; + } + + last_find_offset_dirty_ = true; + *status = kNoError; + return result; +} + +int UniLib::RegexMatcher::Start(int* status) const { + return Start(/*group_idx=*/0, status); +} + +int UniLib::RegexMatcher::Start(int group_idx, int* status) const { + if (!matcher_ || !UpdateLastFindOffset()) { + *status = kError; + return kError; + } + + UErrorCode icu_status = U_ZERO_ERROR; + const int result = matcher_->start(group_idx, icu_status); + if (U_FAILURE(icu_status)) { + *status = kError; + return kError; + } + *status = kNoError; + + // If the group didn't participate in the match the result is -1 and is + // incompatible with the caching logic bellow. 
+ if (result == -1) { + return -1; + } + + return last_find_offset_codepoints_ + + text_.countChar32(/*start=*/last_find_offset_, + /*length=*/result - last_find_offset_); +} + +int UniLib::RegexMatcher::End(int* status) const { + return End(/*group_idx=*/0, status); +} + +int UniLib::RegexMatcher::End(int group_idx, int* status) const { + if (!matcher_ || !UpdateLastFindOffset()) { + *status = kError; + return kError; + } + UErrorCode icu_status = U_ZERO_ERROR; + const int result = matcher_->end(group_idx, icu_status); + if (U_FAILURE(icu_status)) { + *status = kError; + return kError; + } + *status = kNoError; + + // If the group didn't participate in the match the result is -1 and is + // incompatible with the caching logic bellow. + if (result == -1) { + return -1; + } + + return last_find_offset_codepoints_ + + text_.countChar32(/*start=*/last_find_offset_, + /*length=*/result - last_find_offset_); +} + +UnicodeText UniLib::RegexMatcher::Group(int* status) const { + return Group(/*group_idx=*/0, status); +} + +UnicodeText UniLib::RegexMatcher::Group(int group_idx, int* status) const { + if (!matcher_) { + *status = kError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } + std::string result = ""; + UErrorCode icu_status = U_ZERO_ERROR; + const icu::UnicodeString result_icu = matcher_->group(group_idx, icu_status); + if (U_FAILURE(icu_status)) { + *status = kError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } + result_icu.toUTF8String(result); + *status = kNoError; + return UTF8ToUnicodeText(result, /*do_copy=*/true); +} + +constexpr int UniLib::BreakIterator::kDone; + +UniLib::BreakIterator::BreakIterator(const UnicodeText& text) + : text_(icu::UnicodeString::fromUTF8( + icu::StringPiece(text.data(), text.size_bytes()))), + last_break_index_(0), + last_unicode_index_(0) { + icu::ErrorCode status; + break_iterator_.reset( + icu::BreakIterator::createWordInstance(icu::Locale("en"), status)); + if (!status.isSuccess()) { + 
break_iterator_.reset(); + return; + } + break_iterator_->setText(text_); +} + +int UniLib::BreakIterator::Next() { + const int break_index = break_iterator_->next(); + if (break_index == icu::BreakIterator::DONE) { + return BreakIterator::kDone; + } + last_unicode_index_ += + text_.countChar32(last_break_index_, break_index - last_break_index_); + last_break_index_ = break_index; + return last_unicode_index_; +} + +std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern( + const UnicodeText& regex) const { + UErrorCode status = U_ZERO_ERROR; + std::unique_ptr<icu::RegexPattern> pattern( + icu::RegexPattern::compile(icu::UnicodeString::fromUTF8(icu::StringPiece( + regex.data(), regex.size_bytes())), + /*flags=*/UREGEX_MULTILINE, status)); + if (U_FAILURE(status) || !pattern) { + return nullptr; + } + return std::unique_ptr<UniLib::RegexPattern>( + new UniLib::RegexPattern(std::move(pattern))); +} + +std::unique_ptr<UniLib::BreakIterator> UniLib::CreateBreakIterator( + const UnicodeText& text) const { + return std::unique_ptr<UniLib::BreakIterator>( + new UniLib::BreakIterator(text)); +} + +} // namespace libtextclassifier2 diff --git a/util/utf8/unilib-icu.h b/util/utf8/unilib-icu.h new file mode 100644 index 0000000..8983756 --- /dev/null +++ b/util/utf8/unilib-icu.h @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// UniLib implementation with the help of ICU. 
UniLib is basically a wrapper +// around the ICU functionality. + +#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_ +#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_ + +#include <memory> + +#include "util/base/integral_types.h" +#include "util/utf8/unicodetext.h" +#include "unicode/brkiter.h" +#include "unicode/errorcode.h" +#include "unicode/regex.h" +#include "unicode/uchar.h" +#include "unicode/unum.h" + +namespace libtextclassifier2 { + +class UniLib { + public: + bool ParseInt32(const UnicodeText& text, int* result) const; + bool IsOpeningBracket(char32 codepoint) const; + bool IsClosingBracket(char32 codepoint) const; + bool IsWhitespace(char32 codepoint) const; + bool IsDigit(char32 codepoint) const; + bool IsUpper(char32 codepoint) const; + + char32 ToLower(char32 codepoint) const; + char32 GetPairedBracket(char32 codepoint) const; + + // Forward declaration for friend. + class RegexPattern; + + class RegexMatcher { + public: + static constexpr int kError = -1; + static constexpr int kNoError = 0; + + // Checks whether the input text matches the pattern exactly. + bool Matches(int* status) const; + + // Approximate Matches() implementation implemented using Find(). It uses + // the first Find() result and then checks that it spans the whole input. + // NOTE: Unlike Matches() it can result in false negatives. + // NOTE: Resets the matcher, so the current Find() state will be lost. + bool ApproximatelyMatches(int* status); + + // Finds occurrences of the pattern in the input text. + // Can be called repeatedly to find all occurences. A call will update + // internal state, so that 'Start', 'End' and 'Group' can be called to get + // information about the match. + // NOTE: Any call to ApproximatelyMatches() in between Find() calls will + // modify the state. + bool Find(int* status); + + // Gets the start offset of the last match (from 'Find'). + // Sets status to 'kError' if 'Find' + // was not called previously. 
+ int Start(int* status) const; + + // Gets the start offset of the specified group of the last match. + // (from 'Find'). + // Sets status to 'kError' if an invalid group was specified or if 'Find' + // was not called previously. + int Start(int group_idx, int* status) const; + + // Gets the end offset of the last match (from 'Find'). + // Sets status to 'kError' if 'Find' + // was not called previously. + int End(int* status) const; + + // Gets the end offset of the specified group of the last match. + // (from 'Find'). + // Sets status to 'kError' if an invalid group was specified or if 'Find' + // was not called previously. + int End(int group_idx, int* status) const; + + // Gets the text of the last match (from 'Find'). + // Sets status to 'kError' if 'Find' was not called previously. + UnicodeText Group(int* status) const; + + // Gets the text of the specified group of the last match (from 'Find'). + // Sets status to 'kError' if an invalid group was specified or if 'Find' + // was not called previously. 
+ UnicodeText Group(int group_idx, int* status) const; + + protected: + friend class RegexPattern; + explicit RegexMatcher(icu::RegexPattern* pattern, icu::UnicodeString text); + + private: + bool UpdateLastFindOffset() const; + + std::unique_ptr<icu::RegexMatcher> matcher_; + icu::UnicodeString text_; + mutable int last_find_offset_; + mutable int last_find_offset_codepoints_; + mutable bool last_find_offset_dirty_; + }; + + class RegexPattern { + public: + std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& input) const; + + protected: + friend class UniLib; + explicit RegexPattern(std::unique_ptr<icu::RegexPattern> pattern) + : pattern_(std::move(pattern)) {} + + private: + std::unique_ptr<icu::RegexPattern> pattern_; + }; + + class BreakIterator { + public: + int Next(); + + static constexpr int kDone = -1; + + protected: + friend class UniLib; + explicit BreakIterator(const UnicodeText& text); + + private: + std::unique_ptr<icu::BreakIterator> break_iterator_; + icu::UnicodeString text_; + int last_break_index_; + int last_unicode_index_; + }; + + std::unique_ptr<RegexPattern> CreateRegexPattern( + const UnicodeText& regex) const; + std::unique_ptr<BreakIterator> CreateBreakIterator( + const UnicodeText& text) const; +}; + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_ diff --git a/util/utf8/unilib.h b/util/utf8/unilib.h new file mode 100644 index 0000000..29b4575 --- /dev/null +++ b/util/utf8/unilib.h @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_ +#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_ + +#include "util/utf8/unilib-icu.h" +#define CREATE_UNILIB_FOR_TESTING const UniLib unilib; + +#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_ diff --git a/util/utf8/unilib_test.cc b/util/utf8/unilib_test.cc new file mode 100644 index 0000000..13b1347 --- /dev/null +++ b/util/utf8/unilib_test.cc @@ -0,0 +1,232 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "util/utf8/unilib.h" + +#include "util/base/logging.h" +#include "util/utf8/unicodetext.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier2 { +namespace { + +using ::testing::ElementsAre; + +TEST(UniLibTest, CharacterClassesAscii) { + CREATE_UNILIB_FOR_TESTING; + EXPECT_TRUE(unilib.IsOpeningBracket('(')); + EXPECT_TRUE(unilib.IsClosingBracket(')')); + EXPECT_FALSE(unilib.IsWhitespace(')')); + EXPECT_TRUE(unilib.IsWhitespace(' ')); + EXPECT_FALSE(unilib.IsDigit(')')); + EXPECT_TRUE(unilib.IsDigit('0')); + EXPECT_TRUE(unilib.IsDigit('9')); + EXPECT_FALSE(unilib.IsUpper(')')); + EXPECT_TRUE(unilib.IsUpper('A')); + EXPECT_TRUE(unilib.IsUpper('Z')); + EXPECT_EQ(unilib.ToLower('A'), 'a'); + EXPECT_EQ(unilib.ToLower('Z'), 'z'); + EXPECT_EQ(unilib.ToLower(')'), ')'); + EXPECT_EQ(unilib.GetPairedBracket(')'), '('); + EXPECT_EQ(unilib.GetPairedBracket('}'), '{'); +} + +#ifndef LIBTEXTCLASSIFIER_UNILIB_DUMMY +TEST(UniLibTest, CharacterClassesUnicode) { + CREATE_UNILIB_FOR_TESTING; + EXPECT_TRUE(unilib.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON + EXPECT_TRUE(unilib.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS + EXPECT_FALSE(unilib.IsWhitespace(0x23F0)); // ALARM CLOCK + EXPECT_TRUE(unilib.IsWhitespace(0x2003)); // EM SPACE + EXPECT_FALSE(unilib.IsDigit(0xA619)); // VAI SYMBOL JONG + EXPECT_TRUE(unilib.IsDigit(0xA620)); // VAI DIGIT ZERO + EXPECT_TRUE(unilib.IsDigit(0xA629)); // VAI DIGIT NINE + EXPECT_FALSE(unilib.IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA + EXPECT_FALSE(unilib.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE + EXPECT_TRUE(unilib.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE + EXPECT_TRUE(unilib.IsUpper(0x0391)); // GREEK CAPITAL ALPHA + EXPECT_TRUE(unilib.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL + EXPECT_FALSE(unilib.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS + EXPECT_EQ(unilib.ToLower(0x0391), 0x03B1); // GREEK ALPHA + EXPECT_EQ(unilib.ToLower(0x03AB), 0x03CB); // GREEK 
UPSILON WITH DIALYTIKA + EXPECT_EQ(unilib.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI + + EXPECT_EQ(unilib.GetPairedBracket(0x0F3C), 0x0F3D); + EXPECT_EQ(unilib.GetPairedBracket(0x0F3D), 0x0F3C); +} +#endif // ndef LIBTEXTCLASSIFIER_UNILIB_DUMMY + +TEST(UniLibTest, RegexInterface) { + CREATE_UNILIB_FOR_TESTING; + const UnicodeText regex_pattern = + UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true); + std::unique_ptr<UniLib::RegexPattern> pattern = + unilib.CreateRegexPattern(regex_pattern); + const UnicodeText input = UTF8ToUnicodeText("hello 0123", /*do_copy=*/false); + int status; + std::unique_ptr<UniLib::RegexMatcher> matcher = pattern->Matcher(input); + TC_LOG(INFO) << matcher->Matches(&status); + TC_LOG(INFO) << matcher->Find(&status); + TC_LOG(INFO) << matcher->Start(0, &status); + TC_LOG(INFO) << matcher->End(0, &status); + TC_LOG(INFO) << matcher->Group(0, &status).size_codepoints(); +} + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST(UniLibTest, Regex) { + CREATE_UNILIB_FOR_TESTING; + + // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to + // test the regex functionality with it to verify we are handling the indices + // correctly. + const UnicodeText regex_pattern = + UTF8ToUnicodeText("[0-9]+😋", /*do_copy=*/false); + std::unique_ptr<UniLib::RegexPattern> pattern = + unilib.CreateRegexPattern(regex_pattern); + int status; + std::unique_ptr<UniLib::RegexMatcher> matcher; + + matcher = pattern->Matcher(UTF8ToUnicodeText("0123😋", /*do_copy=*/false)); + EXPECT_TRUE(matcher->Matches(&status)); + EXPECT_TRUE(matcher->ApproximatelyMatches(&status)); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset. 
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status)); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + + matcher = pattern->Matcher( + UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false)); + EXPECT_FALSE(matcher->Matches(&status)); + EXPECT_FALSE(matcher->ApproximatelyMatches(&status)); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + + matcher = pattern->Matcher( + UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false)); + EXPECT_TRUE(matcher->Find(&status)); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Start(0, &status), 8); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->End(0, &status), 13); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋"); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST(UniLibTest, RegexGroups) { + CREATE_UNILIB_FOR_TESTING; + + // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to + // test the regex functionality with it to verify we are handling the indices + // correctly. 
+ const UnicodeText regex_pattern = UTF8ToUnicodeText( + "(?<group1>[0-9])(?<group2>[0-9]+)😋", /*do_copy=*/false); + std::unique_ptr<UniLib::RegexPattern> pattern = + unilib.CreateRegexPattern(regex_pattern); + int status; + std::unique_ptr<UniLib::RegexMatcher> matcher; + + matcher = pattern->Matcher( + UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false)); + EXPECT_TRUE(matcher->Find(&status)); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Start(0, &status), 8); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Start(1, &status), 8); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Start(2, &status), 9); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->End(0, &status), 13); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->End(1, &status), 9); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->End(2, &status), 12); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋"); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "0"); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); + EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123"); + EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU + +TEST(UniLibTest, BreakIterator) { + CREATE_UNILIB_FOR_TESTING; + const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false); + std::unique_ptr<UniLib::BreakIterator> iterator = + unilib.CreateBreakIterator(text); + std::vector<int> break_indices; + int break_index = 0; + while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) { + break_indices.push_back(break_index); + } + EXPECT_THAT(break_indices, ElementsAre(4, 5, 9)); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef 
LIBTEXTCLASSIFIER_UNILIB_ICU +TEST(UniLibTest, BreakIterator4ByteUTF8) { + CREATE_UNILIB_FOR_TESTING; + const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false); + std::unique_ptr<UniLib::BreakIterator> iterator = + unilib.CreateBreakIterator(text); + std::vector<int> break_indices; + int break_index = 0; + while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) { + break_indices.push_back(break_index); + } + EXPECT_THAT(break_indices, ElementsAre(1, 2, 3)); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU +TEST(UniLibTest, IntegerParse) { + CREATE_UNILIB_FOR_TESTING; + int result; + EXPECT_TRUE( + unilib.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result)); + EXPECT_EQ(result, 123); +} +#endif // ndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST(UniLibTest, IntegerParseFullWidth) { + CREATE_UNILIB_FOR_TESTING; + int result; + // The input string here is full width + EXPECT_TRUE(unilib.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), + &result)); + EXPECT_EQ(result, 123); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +TEST(UniLibTest, IntegerParseFullWidthWithAlpha) { + CREATE_UNILIB_FOR_TESTING; + int result; + // The input string here is full width + EXPECT_FALSE(unilib.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false), + &result)); +} +#endif // LIBTEXTCLASSIFIER_UNILIB_ICU + +} // namespace +} // namespace libtextclassifier2 diff --git a/zlib-utils.cc b/zlib-utils.cc new file mode 100644 index 0000000..7e6646f --- /dev/null +++ b/zlib-utils.cc @@ -0,0 +1,269 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "zlib-utils.h" + +#include <memory> + +#include "util/base/logging.h" +#include "util/flatbuffers.h" + +namespace libtextclassifier2 { + +std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() { + std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor()); + if (!result->initialized_) { + result.reset(); + } + return result; +} + +ZlibDecompressor::ZlibDecompressor() { + memset(&stream_, 0, sizeof(stream_)); + stream_.zalloc = Z_NULL; + stream_.zfree = Z_NULL; + initialized_ = (inflateInit(&stream_) == Z_OK); +} + +ZlibDecompressor::~ZlibDecompressor() { + if (initialized_) { + inflateEnd(&stream_); + } +} + +bool ZlibDecompressor::Decompress(const CompressedBuffer* compressed_buffer, + std::string* out) { + out->resize(compressed_buffer->uncompressed_size()); + stream_.next_in = + reinterpret_cast<const Bytef*>(compressed_buffer->buffer()->Data()); + stream_.avail_in = compressed_buffer->buffer()->Length(); + stream_.next_out = reinterpret_cast<Bytef*>(const_cast<char*>(out->c_str())); + stream_.avail_out = compressed_buffer->uncompressed_size(); + return (inflate(&stream_, Z_SYNC_FLUSH) == Z_OK); +} + +std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() { + std::unique_ptr<ZlibCompressor> result(new ZlibCompressor()); + if (!result->initialized_) { + result.reset(); + } + return result; +} + +ZlibCompressor::ZlibCompressor(int level, int tmp_buffer_size) { + memset(&stream_, 0, sizeof(stream_)); + stream_.zalloc = Z_NULL; + stream_.zfree = Z_NULL; + buffer_size_ = tmp_buffer_size; + buffer_.reset(new 
Bytef[buffer_size_]); + initialized_ = (deflateInit(&stream_, level) == Z_OK); +} + +ZlibCompressor::~ZlibCompressor() { deflateEnd(&stream_); } + +void ZlibCompressor::Compress(const std::string& uncompressed_content, + CompressedBufferT* out) { + out->uncompressed_size = uncompressed_content.size(); + out->buffer.clear(); + stream_.next_in = + reinterpret_cast<const Bytef*>(uncompressed_content.c_str()); + stream_.avail_in = uncompressed_content.size(); + stream_.next_out = buffer_.get(); + stream_.avail_out = buffer_size_; + unsigned char* buffer_deflate_start_position = + reinterpret_cast<unsigned char*>(buffer_.get()); + int status; + do { + // Deflate chunk-wise. + // Z_SYNC_FLUSH causes all pending output to be flushed, but doesn't + // reset the compression state. + // As we do not know how big the compressed buffer will be, we compress + // chunk wise and append the flushed content to the output string buffer. + // As we store the uncompressed size, we do not have to do this during + // decompression. + status = deflate(&stream_, Z_SYNC_FLUSH); + unsigned char* buffer_deflate_end_position = + reinterpret_cast<unsigned char*>(stream_.next_out); + if (buffer_deflate_end_position != buffer_deflate_start_position) { + out->buffer.insert(out->buffer.end(), buffer_deflate_start_position, + buffer_deflate_end_position); + stream_.next_out = buffer_deflate_start_position; + stream_.avail_out = buffer_size_; + } else { + break; + } + } while (status == Z_OK); +} + +// Compress rule fields in the model. +bool CompressModel(ModelT* model) { + std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance(); + if (!zlib_compressor) { + TC_LOG(ERROR) << "Cannot compress model."; + return false; + } + + // Compress regex rules. 
+ if (model->regex_model != nullptr) { + for (int i = 0; i < model->regex_model->patterns.size(); i++) { + RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); + pattern->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(pattern->pattern, + pattern->compressed_pattern.get()); + pattern->pattern.clear(); + } + } + + // Compress date-time rules. + if (model->datetime_model != nullptr) { + for (int i = 0; i < model->datetime_model->patterns.size(); i++) { + DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); + for (int j = 0; j < pattern->regexes.size(); j++) { + DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); + regex->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(regex->pattern, + regex->compressed_pattern.get()); + regex->pattern.clear(); + } + } + for (int i = 0; i < model->datetime_model->extractors.size(); i++) { + DatetimeModelExtractorT* extractor = + model->datetime_model->extractors[i].get(); + extractor->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(extractor->pattern, + extractor->compressed_pattern.get()); + extractor->pattern.clear(); + } + } + return true; +} + +namespace { + +bool DecompressBuffer(const CompressedBufferT* compressed_pattern, + ZlibDecompressor* zlib_decompressor, + std::string* uncompressed_pattern) { + std::string packed_pattern = + PackFlatbuffer<CompressedBuffer>(compressed_pattern); + if (!zlib_decompressor->Decompress( + LoadAndVerifyFlatbuffer<CompressedBuffer>(packed_pattern), + uncompressed_pattern)) { + return false; + } + return true; +} + +} // namespace + +bool DecompressModel(ModelT* model) { + std::unique_ptr<ZlibDecompressor> zlib_decompressor = + ZlibDecompressor::Instance(); + if (!zlib_decompressor) { + TC_LOG(ERROR) << "Cannot initialize decompressor."; + return false; + } + + // Decompress regex rules. 
+ if (model->regex_model != nullptr) { + for (int i = 0; i < model->regex_model->patterns.size(); i++) { + RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); + if (!DecompressBuffer(pattern->compressed_pattern.get(), + zlib_decompressor.get(), &pattern->pattern)) { + TC_LOG(ERROR) << "Cannot decompress pattern: " << i; + return false; + } + pattern->compressed_pattern.reset(nullptr); + } + } + + // Decompress date-time rules. + if (model->datetime_model != nullptr) { + for (int i = 0; i < model->datetime_model->patterns.size(); i++) { + DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); + for (int j = 0; j < pattern->regexes.size(); j++) { + DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); + if (!DecompressBuffer(regex->compressed_pattern.get(), + zlib_decompressor.get(), ®ex->pattern)) { + TC_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j; + return false; + } + regex->compressed_pattern.reset(nullptr); + } + } + for (int i = 0; i < model->datetime_model->extractors.size(); i++) { + DatetimeModelExtractorT* extractor = + model->datetime_model->extractors[i].get(); + if (!DecompressBuffer(extractor->compressed_pattern.get(), + zlib_decompressor.get(), &extractor->pattern)) { + TC_LOG(ERROR) << "Cannot decompress pattern: " << i; + return false; + } + extractor->compressed_pattern.reset(nullptr); + } + } + return true; +} + +std::string CompressSerializedModel(const std::string& model) { + std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str()); + TC_CHECK(unpacked_model != nullptr); + TC_CHECK(CompressModel(unpacked_model.get())); + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); +} + +std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern( + const UniLib& unilib, const flatbuffers::String* 
uncompressed_pattern, + const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor, + std::string* result_pattern_text) { + UnicodeText unicode_regex_pattern; + std::string decompressed_pattern; + if (compressed_pattern != nullptr && + compressed_pattern->buffer() != nullptr) { + if (decompressor == nullptr || + !decompressor->Decompress(compressed_pattern, &decompressed_pattern)) { + TC_LOG(ERROR) << "Cannot decompress pattern."; + return nullptr; + } + unicode_regex_pattern = + UTF8ToUnicodeText(decompressed_pattern.data(), + decompressed_pattern.size(), /*do_copy=*/false); + } else { + if (uncompressed_pattern == nullptr) { + TC_LOG(ERROR) << "Cannot load uncompressed pattern."; + return nullptr; + } + unicode_regex_pattern = + UTF8ToUnicodeText(uncompressed_pattern->c_str(), + uncompressed_pattern->Length(), /*do_copy=*/false); + } + + if (result_pattern_text != nullptr) { + *result_pattern_text = unicode_regex_pattern.ToUTF8String(); + } + + std::unique_ptr<UniLib::RegexPattern> regex_pattern = + unilib.CreateRegexPattern(unicode_regex_pattern); + if (!regex_pattern) { + TC_LOG(ERROR) << "Could not create pattern: " + << unicode_regex_pattern.ToUTF8String(); + } + return regex_pattern; +} + +} // namespace libtextclassifier2 diff --git a/zlib-utils.h b/zlib-utils.h new file mode 100644 index 0000000..136f4d2 --- /dev/null +++ b/zlib-utils.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Functions to compress and decompress low entropy entries in the model. + +#ifndef LIBTEXTCLASSIFIER_ZLIB_UTILS_H_ +#define LIBTEXTCLASSIFIER_ZLIB_UTILS_H_ + +#include <memory> + +#include "model_generated.h" +#include "util/utf8/unilib.h" +#include "zlib.h" + +namespace libtextclassifier2 { + +class ZlibDecompressor { + public: + static std::unique_ptr<ZlibDecompressor> Instance(); + ~ZlibDecompressor(); + + bool Decompress(const CompressedBuffer* compressed_buffer, std::string* out); + + private: + ZlibDecompressor(); + z_stream stream_; + bool initialized_; +}; + +class ZlibCompressor { + public: + static std::unique_ptr<ZlibCompressor> Instance(); + ~ZlibCompressor(); + + void Compress(const std::string& uncompressed_content, + CompressedBufferT* out); + + private: + explicit ZlibCompressor(int level = Z_BEST_COMPRESSION, + // Tmp. buffer size was set based on the current set + // of patterns to be compressed. + int tmp_buffer_size = 64 * 1024); + z_stream stream_; + std::unique_ptr<Bytef[]> buffer_; + unsigned int buffer_size_; + bool initialized_; +}; + +// Compresses regex and datetime rules in the model in place. +bool CompressModel(ModelT* model); + +// Decompresses regex and datetime rules in the model in place. +bool DecompressModel(ModelT* model); + +// Compresses regex and datetime rules in the model. +std::string CompressSerializedModel(const std::string& model); + +// Create and compile a regex pattern from optionally compressed pattern. 
+std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern( + const UniLib& unilib, const flatbuffers::String* uncompressed_pattern, + const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor, + std::string* result_pattern_text = nullptr); + +} // namespace libtextclassifier2 + +#endif // LIBTEXTCLASSIFIER_ZLIB_UTILS_H_ diff --git a/zlib-utils_test.cc b/zlib-utils_test.cc new file mode 100644 index 0000000..155f14f --- /dev/null +++ b/zlib-utils_test.cc @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "zlib-utils.h" + +#include <memory> + +#include "model_generated.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier2 { + +TEST(ZlibUtilsTest, CompressModel) { + ModelT model; + model.regex_model.reset(new RegexModelT); + model.regex_model->patterns.emplace_back(new RegexModel_::PatternT); + model.regex_model->patterns.back()->pattern = "this is a test pattern"; + model.regex_model->patterns.emplace_back(new RegexModel_::PatternT); + model.regex_model->patterns.back()->pattern = "this is a second test pattern"; + + model.datetime_model.reset(new DatetimeModelT); + model.datetime_model->patterns.emplace_back(new DatetimeModelPatternT); + model.datetime_model->patterns.back()->regexes.emplace_back( + new DatetimeModelPattern_::RegexT); + model.datetime_model->patterns.back()->regexes.back()->pattern = + "an example datetime pattern"; + model.datetime_model->extractors.emplace_back(new DatetimeModelExtractorT); + model.datetime_model->extractors.back()->pattern = + "an example datetime extractor"; + + // Compress the model. + EXPECT_TRUE(CompressModel(&model)); + + // Sanity check that uncompressed field is removed. + EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty()); + EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty()); + EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty()); + EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty()); + + // Pack and load the model. + flatbuffers::FlatBufferBuilder builder; + builder.Finish(Model::Pack(builder, &model)); + const Model* compressed_model = + GetModel(reinterpret_cast<const char*>(builder.GetBufferPointer())); + ASSERT_TRUE(compressed_model != nullptr); + + // Decompress the fields again and check that they match the original. 
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(); + ASSERT_TRUE(decompressor != nullptr); + std::string uncompressed_pattern; + EXPECT_TRUE(decompressor->Decompress( + compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(), + &uncompressed_pattern)); + EXPECT_EQ(uncompressed_pattern, "this is a test pattern"); + EXPECT_TRUE(decompressor->Decompress( + compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(), + &uncompressed_pattern)); + EXPECT_EQ(uncompressed_pattern, "this is a second test pattern"); + EXPECT_TRUE(decompressor->Decompress(compressed_model->datetime_model() + ->patterns() + ->Get(0) + ->regexes() + ->Get(0) + ->compressed_pattern(), + &uncompressed_pattern)); + EXPECT_EQ(uncompressed_pattern, "an example datetime pattern"); + EXPECT_TRUE(decompressor->Decompress(compressed_model->datetime_model() + ->extractors() + ->Get(0) + ->compressed_pattern(), + &uncompressed_pattern)); + EXPECT_EQ(uncompressed_pattern, "an example datetime extractor"); + + EXPECT_TRUE(DecompressModel(&model)); + EXPECT_EQ(model.regex_model->patterns[0]->pattern, "this is a test pattern"); + EXPECT_EQ(model.regex_model->patterns[1]->pattern, + "this is a second test pattern"); + EXPECT_EQ(model.datetime_model->patterns[0]->regexes[0]->pattern, + "an example datetime pattern"); + EXPECT_EQ(model.datetime_model->extractors[0]->pattern, + "an example datetime extractor"); +} + +} // namespace libtextclassifier2 |