diff options
author | Xin Li <delphij@google.com> | 2019-10-30 11:48:05 -0700 |
---|---|---|
committer | Xin Li <delphij@google.com> | 2019-10-30 11:48:05 -0700 |
commit | 4fbf2a638a67c3ae368a25797149d4b59742271e (patch) | |
tree | 2dfdfd492f12e83f577d8c94eb1e2cadf4b3202f | |
parent | 13252da94c05513899c8949a8305e995056d1664 (diff) | |
parent | 12973dcf3b8eab3792cf3cf23980168523ddb73d (diff) | |
download | libtextclassifier-temp_b_145570283.tar.gz |
DO NOT MERGE - qt-qpr1-dev-plus-aosp-without-vendor@5915889 into stage-aosp-master (ref: temp_b_145570283)
Bug: 142003500
Change-Id: I1d3af7ee8f417596366e8b70bf5fd84184717028
-rw-r--r-- | lang_id/lang-id.cc | 48 | ||||
-rw-r--r-- | lang_id/lang-id_jni.cc | 12 |
2 files changed, 48 insertions, 12 deletions
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc index c892329..1339223 100644 --- a/lang_id/lang-id.cc +++ b/lang_id/lang-id.cc @@ -100,8 +100,17 @@ class LangIdImpl { return LangId::kUnknownLanguageCode; } + // Create a Sentence storing the input text. + LightSentence sentence; + tokenizer_.Tokenize(text, &sentence); + + // Test input size here, after pre-processing removed irrelevant chars. + if (IsTooShort(sentence)) { + return LangId::kUnknownLanguageCode; + } + std::vector<float> scores; - ComputeScores(text, &scores); + ComputeScores(&sentence, &scores); int prediction_id = GetArgMax(scores); const string language = GetLanguageForSoftmaxLabel(prediction_id); @@ -133,8 +142,18 @@ class LangIdImpl { return; } + // Create a Sentence storing the input text. + LightSentence sentence; + tokenizer_.Tokenize(text, &sentence); + + // Test input size here, after pre-processing removed irrelevant chars. + if (IsTooShort(sentence)) { + result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1); + return; + } + std::vector<float> scores; - ComputeScores(text, &scores); + ComputeScores(&sentence, &scores); // Compute and sort softmax in descending order by probability and convert // IDs to language code strings. When probabilities are equal, we sort by @@ -173,6 +192,8 @@ class LangIdImpl { bool Setup(TaskContext *context) { tokenizer_.Setup(context); if (!lang_id_brain_interface_.SetupForProcessing(context)) return false; + + min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0); default_threshold_ = context->Get("reliability_thresh", kDefaultConfidenceThreshold); @@ -203,13 +224,9 @@ class LangIdImpl { // network, and computes the output scores (activations from the last layer). // These scores can be used to compute the softmax probabilities for our // labels (in this case, the languages). - void ComputeScores(StringPiece text, std::vector<float> *scores) const { - // Create a Sentence storing the input text. 
- LightSentence sentence; - tokenizer_.Tokenize(text, &sentence); - + void ComputeScores(LightSentence* sentence, std::vector<float> *scores) const { std::vector<FeatureVector> features = - lang_id_brain_interface_.GetFeaturesNoCaching(&sentence); + lang_id_brain_interface_.GetFeaturesNoCaching(sentence); // Run feed-forward neural network to compute scores. network_->ComputeFinalScores(features, scores); @@ -227,6 +244,16 @@ class LangIdImpl { } } + bool IsTooShort(const LightSentence &sentence) const { + int text_size = 0; + for (const std::string &token : sentence) { + // Each token has the form ^...$: we subtract 2 because we want to count + // only the real text, not the chars added by us. + text_size += token.size() - 2; + } + return text_size < min_text_size_in_bytes_; + } + std::unique_ptr<ModelProvider> model_provider_; TokenizerForLangId tokenizer_; @@ -240,6 +267,11 @@ class LangIdImpl { // True if this object is ready to perform language predictions. bool valid_ = false; + // The model returns LangId::kUnknownLanguageCode for input text that has + // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces, + // digits, and punctuation). + int min_text_size_in_bytes_ = 0; + // Only predictions with a probability (confidence) above this threshold are // reported. Otherwise, we report LangId::kUnknownLanguageCode. 
float default_threshold_ = kDefaultConfidenceThreshold; diff --git a/lang_id/lang-id_jni.cc b/lang_id/lang-id_jni.cc index 6696298..61547e5 100644 --- a/lang_id/lang-id_jni.cc +++ b/lang_id/lang-id_jni.cc @@ -44,10 +44,14 @@ jobjectArray LangIdResultToJObjectArray(JNIEnv* env, return nullptr; } - // clang-format off - const std::vector<std::pair<std::string, float>>& predictions = - lang_id_result.predictions; - // clang-format on + std::vector<std::pair<std::string, float>> predictions; + std::copy_if(lang_id_result.predictions.begin(), + lang_id_result.predictions.end(), + std::back_inserter(predictions), + [](std::pair<std::string, float> pair) { + return pair.first != "und"; + }); + const jmethodID result_class_constructor = env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V"); const jobjectArray results = |