summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: android-build-team Robot <android-build-team-robot@google.com> 2019-10-28 15:35:39 +0000
committer: android-build-team Robot <android-build-team-robot@google.com> 2019-10-28 15:35:39 +0000
commit: 126fcf3f9e17f6e87c261611df0a0ab3e3d9b01d (patch)
tree: 768c65f44cb8326b9d6d0228cc3405edfbee3787
parent: 32f4be9b671894fc67ad3fe18b130957545ec9a8 (diff)
parent: 832fd3280739abfad08cd7fb0ecdb481c84c561a (diff)
download: libtextclassifier-android10-d4-release.tar.gz
Snap for 5970985 from 832fd3280739abfad08cd7fb0ecdb481c84c561a to qt-d4-release
Refs: android-10.0.0_r45, android-10.0.0_r44, android-10.0.0_r43, android-10.0.0_r42, android10-d4-s1-release, android10-d4-release
Change-Id: I37a8a92d92850af9826242a877102df7120b7d3d
-rw-r--r-- lang_id/lang-id.cc 48
-rw-r--r-- lang_id/lang-id_jni.cc 12
2 files changed, 48 insertions, 12 deletions
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
index c892329..1339223 100644
--- a/lang_id/lang-id.cc
+++ b/lang_id/lang-id.cc
@@ -100,8 +100,17 @@ class LangIdImpl {
return LangId::kUnknownLanguageCode;
}
+ // Create a Sentence storing the input text.
+ LightSentence sentence;
+ tokenizer_.Tokenize(text, &sentence);
+
+ // Test input size here, after pre-processing removed irrelevant chars.
+ if (IsTooShort(sentence)) {
+ return LangId::kUnknownLanguageCode;
+ }
+
std::vector<float> scores;
- ComputeScores(text, &scores);
+ ComputeScores(&sentence, &scores);
int prediction_id = GetArgMax(scores);
const string language = GetLanguageForSoftmaxLabel(prediction_id);
@@ -133,8 +142,18 @@ class LangIdImpl {
return;
}
+ // Create a Sentence storing the input text.
+ LightSentence sentence;
+ tokenizer_.Tokenize(text, &sentence);
+
+ // Test input size here, after pre-processing removed irrelevant chars.
+ if (IsTooShort(sentence)) {
+ result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
+ return;
+ }
+
std::vector<float> scores;
- ComputeScores(text, &scores);
+ ComputeScores(&sentence, &scores);
// Compute and sort softmax in descending order by probability and convert
// IDs to language code strings. When probabilities are equal, we sort by
@@ -173,6 +192,8 @@ class LangIdImpl {
bool Setup(TaskContext *context) {
tokenizer_.Setup(context);
if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
+
+ min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
default_threshold_ =
context->Get("reliability_thresh", kDefaultConfidenceThreshold);
@@ -203,13 +224,9 @@ class LangIdImpl {
// network, and computes the output scores (activations from the last layer).
// These scores can be used to compute the softmax probabilities for our
// labels (in this case, the languages).
- void ComputeScores(StringPiece text, std::vector<float> *scores) const {
- // Create a Sentence storing the input text.
- LightSentence sentence;
- tokenizer_.Tokenize(text, &sentence);
-
+ void ComputeScores(LightSentence* sentence, std::vector<float> *scores) const {
std::vector<FeatureVector> features =
- lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
+ lang_id_brain_interface_.GetFeaturesNoCaching(sentence);
// Run feed-forward neural network to compute scores.
network_->ComputeFinalScores(features, scores);
@@ -227,6 +244,16 @@ class LangIdImpl {
}
}
+ // Returns true if the tokens of |sentence| hold, in total, fewer than
+ // min_text_size_in_bytes_ bytes of real text (i.e., not counting the marker
+ // chars wrapped around each token).
+ bool IsTooShort(const LightSentence &sentence) const {
+   int text_size = 0;
+   for (const std::string &token : sentence) {
+     // Each token has the form ^...$: we subtract 2 because we want to count
+     // only the real text, not the chars added by us.  Convert the size to int
+     // before subtracting so the arithmetic is signed: computing
+     // token.size() - 2 in size_t would wrap around for a token shorter than
+     // two bytes and would be implicitly narrowed when added to text_size.
+     text_size += static_cast<int>(token.size()) - 2;
+   }
+   return text_size < min_text_size_in_bytes_;
+ }
+
std::unique_ptr<ModelProvider> model_provider_;
TokenizerForLangId tokenizer_;
@@ -240,6 +267,11 @@ class LangIdImpl {
// True if this object is ready to perform language predictions.
bool valid_ = false;
+ // The model returns LangId::kUnknownLanguageCode for input text that has
+ // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
+ // digits, and punctuation).
+ int min_text_size_in_bytes_ = 0;
+
// Only predictions with a probability (confidence) above this threshold are
// reported. Otherwise, we report LangId::kUnknownLanguageCode.
float default_threshold_ = kDefaultConfidenceThreshold;
diff --git a/lang_id/lang-id_jni.cc b/lang_id/lang-id_jni.cc
index 6696298..61547e5 100644
--- a/lang_id/lang-id_jni.cc
+++ b/lang_id/lang-id_jni.cc
@@ -44,10 +44,14 @@ jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
return nullptr;
}
- // clang-format off
- const std::vector<std::pair<std::string, float>>& predictions =
- lang_id_result.predictions;
- // clang-format on
+ std::vector<std::pair<std::string, float>> predictions;
+ std::copy_if(lang_id_result.predictions.begin(),
+ lang_id_result.predictions.end(),
+ std::back_inserter(predictions),
+ [](std::pair<std::string, float> pair) {
+ return pair.first != "und";
+ });
+
const jmethodID result_class_constructor =
env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
const jobjectArray results =