diff options
author | android-build-team Robot <android-build-team-robot@google.com> | 2019-10-28 15:35:39 +0000 |
---|---|---|
committer | android-build-team Robot <android-build-team-robot@google.com> | 2019-10-28 15:35:39 +0000 |
commit | 126fcf3f9e17f6e87c261611df0a0ab3e3d9b01d (patch) | |
tree | 768c65f44cb8326b9d6d0228cc3405edfbee3787 | |
parent | 32f4be9b671894fc67ad3fe18b130957545ec9a8 (diff) | |
parent | 832fd3280739abfad08cd7fb0ecdb481c84c561a (diff) | |
download | libtextclassifier-android10-d4-release.tar.gz |
Snap for 5970985 from 832fd3280739abfad08cd7fb0ecdb481c84c561a to qt-d4-release [refs: android-10.0.0_r45, android-10.0.0_r44, android-10.0.0_r43, android-10.0.0_r42, android10-d4-s1-release, android10-d4-release]
Change-Id: I37a8a92d92850af9826242a877102df7120b7d3d
-rw-r--r-- | lang_id/lang-id.cc | 48 | ||||
-rw-r--r-- | lang_id/lang-id_jni.cc | 12 |
2 files changed, 48 insertions, 12 deletions
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc index c892329..1339223 100644 --- a/lang_id/lang-id.cc +++ b/lang_id/lang-id.cc @@ -100,8 +100,17 @@ class LangIdImpl { return LangId::kUnknownLanguageCode; } + // Create a Sentence storing the input text. + LightSentence sentence; + tokenizer_.Tokenize(text, &sentence); + + // Test input size here, after pre-processing removed irrelevant chars. + if (IsTooShort(sentence)) { + return LangId::kUnknownLanguageCode; + } + std::vector<float> scores; - ComputeScores(text, &scores); + ComputeScores(&sentence, &scores); int prediction_id = GetArgMax(scores); const string language = GetLanguageForSoftmaxLabel(prediction_id); @@ -133,8 +142,18 @@ class LangIdImpl { return; } + // Create a Sentence storing the input text. + LightSentence sentence; + tokenizer_.Tokenize(text, &sentence); + + // Test input size here, after pre-processing removed irrelevant chars. + if (IsTooShort(sentence)) { + result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1); + return; + } + std::vector<float> scores; - ComputeScores(text, &scores); + ComputeScores(&sentence, &scores); // Compute and sort softmax in descending order by probability and convert // IDs to language code strings. When probabilities are equal, we sort by @@ -173,6 +192,8 @@ class LangIdImpl { bool Setup(TaskContext *context) { tokenizer_.Setup(context); if (!lang_id_brain_interface_.SetupForProcessing(context)) return false; + + min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0); default_threshold_ = context->Get("reliability_thresh", kDefaultConfidenceThreshold); @@ -203,13 +224,9 @@ class LangIdImpl { // network, and computes the output scores (activations from the last layer). // These scores can be used to compute the softmax probabilities for our // labels (in this case, the languages). - void ComputeScores(StringPiece text, std::vector<float> *scores) const { - // Create a Sentence storing the input text. 
- LightSentence sentence; - tokenizer_.Tokenize(text, &sentence); - + void ComputeScores(LightSentence* sentence, std::vector<float> *scores) const { std::vector<FeatureVector> features = - lang_id_brain_interface_.GetFeaturesNoCaching(&sentence); + lang_id_brain_interface_.GetFeaturesNoCaching(sentence); // Run feed-forward neural network to compute scores. network_->ComputeFinalScores(features, scores); @@ -227,6 +244,16 @@ class LangIdImpl { } } + bool IsTooShort(const LightSentence &sentence) const { + int text_size = 0; + for (const std::string &token : sentence) { + // Each token has the form ^...$: we subtract 2 because we want to count + // only the real text, not the chars added by us. + text_size += token.size() - 2; + } + return text_size < min_text_size_in_bytes_; + } + std::unique_ptr<ModelProvider> model_provider_; TokenizerForLangId tokenizer_; @@ -240,6 +267,11 @@ class LangIdImpl { // True if this object is ready to perform language predictions. bool valid_ = false; + // The model returns LangId::kUnknownLanguageCode for input text that has + // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces, + // digits, and punctuation). + int min_text_size_in_bytes_ = 0; + // Only predictions with a probability (confidence) above this threshold are // reported. Otherwise, we report LangId::kUnknownLanguageCode. 
float default_threshold_ = kDefaultConfidenceThreshold; diff --git a/lang_id/lang-id_jni.cc b/lang_id/lang-id_jni.cc index 6696298..61547e5 100644 --- a/lang_id/lang-id_jni.cc +++ b/lang_id/lang-id_jni.cc @@ -44,10 +44,14 @@ jobjectArray LangIdResultToJObjectArray(JNIEnv* env, return nullptr; } - // clang-format off - const std::vector<std::pair<std::string, float>>& predictions = - lang_id_result.predictions; - // clang-format on + std::vector<std::pair<std::string, float>> predictions; + std::copy_if(lang_id_result.predictions.begin(), + lang_id_result.predictions.end(), + std::back_inserter(predictions), + [](std::pair<std::string, float> pair) { + return pair.first != "und"; + }); + const jmethodID result_class_constructor = env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V"); const jobjectArray results = |