diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-11-09 09:17:22 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-11-09 09:17:22 +0000 |
commit | 5e8124d7db4c92436101aeec8fa1a905470d60dd (patch) | |
tree | 101bfc731fc453160421993180e07d015e40bc56 | |
parent | 01d8b5cc8dd6a4ae2eab6e8b3d32b76e6f271c63 (diff) | |
parent | 1075b1e4e39ab4af90deb3758e5631943c07d47e (diff) | |
download | libtextclassifier-5e8124d7db4c92436101aeec8fa1a905470d60dd.tar.gz |
Snap for 9271768 from 1075b1e4e39ab4af90deb3758e5631943c07d47e to mainline-sdkext-releaseaml_sdk_331310010
Change-Id: If025cebdc7cb7150ee7d84fc3bcf200b2c556d46
15 files changed, 93 insertions, 22 deletions
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc index 9f9a8d4..eeeb508 100644 --- a/native/actions/actions-suggestions.cc +++ b/native/actions/actions-suggestions.cc @@ -21,6 +21,8 @@ #include <vector> #include "utils/base/statusor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/random.h" #if !defined(TC3_DISABLE_LUA) #include "actions/lua-actions.h" @@ -42,6 +44,7 @@ #include "utils/strings/utf8.h" #include "utils/utf8/unicodetext.h" #include "absl/container/flat_hash_set.h" +#include "absl/random/distributions.h" #include "tensorflow/lite/string_util.h" namespace libtextclassifier3 { @@ -813,6 +816,8 @@ void ActionsSuggestions::PopulateTextReplies( const tflite::Interpreter* interpreter, int suggestion_index, int score_index, const std::string& type, float priority_score, const absl::flat_hash_set<std::string>& blocklist, + const absl::flat_hash_map<std::string, std::vector<std::string>>& + concept_mappings, ActionsSuggestionsResponse* response) const { const std::vector<tflite::StringRef> replies = model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter); @@ -831,6 +836,12 @@ void ActionsSuggestions::PopulateTextReplies( if (blocklist.contains(response_text)) { continue; } + if (concept_mappings.contains(response_text)) { + const int candidates_size = concept_mappings.at(response_text).size(); + const int candidate_index = absl::Uniform<int>( + absl::IntervalOpenOpen, bit_gen_, 0, candidates_size); + response_text = concept_mappings.at(response_text)[candidate_index]; + } response->actions.push_back({response_text, type, score, priority_score}); } @@ -918,11 +929,11 @@ bool ActionsSuggestions::ReadModelOutput( if (!response->output_filtered_min_triggering_score && model_->tflite_model_spec()->output_replies() >= 0) { absl::flat_hash_set<std::string> empty_blocklist; - PopulateTextReplies(interpreter, - model_->tflite_model_spec()->output_replies(), - model_->tflite_model_spec()->output_replies_scores(), - model_->smart_reply_action_type()->str(), - /* priority_score */ 0.0, empty_blocklist, response); + PopulateTextReplies( + interpreter, model_->tflite_model_spec()->output_replies(), + model_->tflite_model_spec()->output_replies_scores(), + model_->smart_reply_action_type()->str(), + /* priority_score */ 0.0, empty_blocklist, {}, response); } // Read actions suggestions. @@ -961,6 +972,8 @@ bool ActionsSuggestions::ReadModelOutput( const int suggestions_scores_index = metadata->output_suggestions_scores(); absl::flat_hash_set<std::string> response_text_blocklist; + absl::flat_hash_map<std::string, std::vector<std::string>> + concept_mappings; switch (metadata->prediction_type()) { case PredictionType_NEXT_MESSAGE_PREDICTION: if (!task_spec || task_spec->type()->size() == 0) { @@ -973,13 +986,22 @@ bool ActionsSuggestions::ReadModelOutput( response_text_blocklist.insert(val->str()); } } + if (task_spec->concept_mappings()) { + for (const auto& concept : *task_spec->concept_mappings()) { + std::vector<std::string> candidates; + for (const auto& candidate : *concept->candidates()) { + candidates.push_back(candidate->str()); + } + concept_mappings[concept->concept_name()->str()] = candidates; + } + } } PopulateTextReplies( interpreter, suggestions_index, suggestions_scores_index, task_spec ? task_spec->type()->str() : model_->smart_reply_action_type()->str(), task_spec ? task_spec->priority_score() : 0.0, - response_text_blocklist, response); + response_text_blocklist, concept_mappings, response); break; case PredictionType_INTENT_TRIGGERING: PopulateIntentTriggering(interpreter, suggestions_index, diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h index 87f55fb..c3d58e4 100644 --- a/native/actions/actions-suggestions.h +++ b/native/actions/actions-suggestions.h @@ -43,7 +43,9 @@ #include "utils/utf8/unilib.h" #include "utils/variant.h" #include "utils/zlib/zlib.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/random/random.h" namespace libtextclassifier3 { @@ -175,11 +177,13 @@ class ActionsSuggestions { void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const; - void PopulateTextReplies(const tflite::Interpreter* interpreter, - int suggestion_index, int score_index, - const std::string& type, float priority_score, - const absl::flat_hash_set<std::string>& blocklist, - ActionsSuggestionsResponse* response) const; + void PopulateTextReplies( + const tflite::Interpreter* interpreter, int suggestion_index, + int score_index, const std::string& type, float priority_score, + const absl::flat_hash_set<std::string>& blocklist, + const absl::flat_hash_map<std::string, std::vector<std::string>>& + concept_mappings, + ActionsSuggestionsResponse* response) const; void PopulateIntentTriggering(const tflite::Interpreter* interpreter, int suggestion_index, int score_index, @@ -273,6 +277,9 @@ class ActionsSuggestions { // Conversation intent detection model for additional actions. std::unique_ptr<const ConversationIntentDetection> conversation_intent_detection_; + + // Used for randomly selecting candidates. + mutable absl::BitGen bit_gen_; }; // Interprets the buffer as a Model flatbuffer and returns it for reading. diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc index b51ebc7..65f9796 100644 --- a/native/actions/actions-suggestions_test.cc +++ b/native/actions/actions-suggestions_test.cc @@ -61,6 +61,8 @@ constexpr char kMultiTaskSrP13nModelFileName[] = "actions_suggestions_test.multi_task_sr_p13n.model"; constexpr char kMultiTaskSrEmojiModelFileName[] = "actions_suggestions_test.multi_task_sr_emoji.model"; +constexpr char kMultiTaskSrEmojiConceptModelFileName[] = + "actions_suggestions_test.multi_task_sr_emoji_concept.model"; constexpr char kSensitiveTFliteModelFileName[] = "actions_suggestions_test.sensitive_tflite.model"; constexpr char kLiveRelayTFLiteModelFileName[] = @@ -1835,6 +1837,25 @@ TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) { EXPECT_EQ(response.actions[2].type, "text_reply"); } +TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) { + std::unique_ptr<ActionsSuggestions> actions_suggestions = + LoadTestModel(kMultiTaskSrEmojiConceptModelFileName); + + const ActionsSuggestionsResponse response = + actions_suggestions->SuggestActions( + {{{/*user_id=*/1, "i am tired", + /*reference_time_ms_utc=*/0, + /*reference_timezone=*/"Europe/Zurich", + /*annotations=*/{}, + /*locales=*/"en"}}}); + std::vector<std::string> sigh_emojis = {"😔", "😞"}; + + EXPECT_TRUE(std::find(sigh_emojis.begin(), sigh_emojis.end(), + response.actions[0].response_text) != + sigh_emojis.end()); + EXPECT_EQ(response.actions[0].type, "emoji_reply"); +} + TEST_F(ActionsSuggestionsTest, LiveRelayModel) { std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(kLiveRelayTFLiteModelFileName); diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs index 0d8c7ad..70f9104 100644 --- a/native/actions/actions_model.fbs +++ b/native/actions/actions_model.fbs @@ -312,6 +312,15 @@ table TriggeringPreconditions { min_reply_score_threshold:float = 0; } +// This proto handles model outputs that are concepts, such as emoji concept +// suggestion models. Each concept maps to a list of candidates. One of +// the candidates is chosen randomly as the final suggestion. +namespace libtextclassifier3; +table ActionConceptToSuggestion { + concept_name:string (shared); + candidates:[string]; +} + namespace libtextclassifier3; table ActionSuggestionSpec { // Type of the action suggestion. @@ -331,6 +340,10 @@ table ActionSuggestionSpec { entity_data:ActionsEntityData; response_text_blocklist:[string]; + + // If provided, map the response as concept to one of the corresponding + // candidates. + concept_mappings:[ActionConceptToSuggestion]; } // Options to specify triggering behaviour per action class. diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model Binary files differindex a44bfe6..6d7bdb0 100644 --- a/native/actions/test_data/actions_suggestions_grammar_test.model +++ b/native/actions/test_data/actions_suggestions_grammar_test.model diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model Binary files differindex d262953..88f62eb 100644 --- a/native/actions/test_data/actions_suggestions_test.model +++ b/native/actions/test_data/actions_suggestions_test.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model Binary files differindex 96fd5ef..40a2409 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model Binary files differindex b77d2b5..effb2cb 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model Binary files differnew file mode 100644 index 0000000..18333d6 --- /dev/null +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model Binary files differindex ad9b684..e41ab39 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model Binary files differindex 7fa095a..5314b43 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model Binary files differindex 33bb389..a633742 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model Binary files differindex 11b7524..6685d26 100644 --- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model +++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java index 9429b29..28f947b 100644 --- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java +++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java @@ -96,16 +96,11 @@ public class SmartSuggestionsHelper { oldSession.destroy(); } }; - private final TextClassificationContext textClassificationContext; public SmartSuggestionsHelper(Context context, SmartSuggestionsConfig config) { this.context = context; textClassificationManager = this.context.getSystemService(TextClassificationManager.class); this.config = config; - this.textClassificationContext = - new TextClassificationContext.Builder( - context.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION) - .build(); } /** @@ -170,7 +165,10 @@ public class SmartSuggestionsHelper { } else { SmartSuggestionsLogSession session = new SmartSuggestionsLogSession( - resultId, repliesScore, textClassifier, textClassificationContext); + resultId, + repliesScore, + textClassifier, + getTextClassificationContext(statusBarNotification)); session.onSuggestionsGenerated(conversationActions); // Store the session if we expect more logging from it, destroy it otherwise. @@ -302,7 +300,11 @@ public class SmartSuggestionsHelper { .setTypeConfig(typeConfigBuilder.build()) .build(); - TextClassifier textClassifier = createTextClassificationSession(); + TextClassifier textClassifier = + textClassificationManager.createTextClassificationSession( + getTextClassificationContext(statusBarNotification)); + onTextClassificationSessionCreated(); + return new SuggestConversationActionsResult( Optional.of(textClassifier), textClassifier.suggestConversationActions(request)); } @@ -477,8 +479,13 @@ public class SmartSuggestionsHelper { } @VisibleForTesting - TextClassifier createTextClassificationSession() { - return textClassificationManager.createTextClassificationSession(textClassificationContext); + void onTextClassificationSessionCreated() {} + + private static TextClassificationContext getTextClassificationContext( + StatusBarNotification statusBarNotification) { + return new TextClassificationContext.Builder( + statusBarNotification.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION) + .build(); } private static boolean arePersonsEqual(Person left, Person right) { diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java index 84cf4fb..9354819 100644 --- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java +++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java @@ -86,9 +86,8 @@ public class SmartSuggestionsHelperTest { } @Override - TextClassifier createTextClassificationSession() { + void onTextClassificationSessionCreated() { numOfSessionsCreated += 1; - return super.createTextClassificationSession(); } int getNumOfSessionsCreated() { @@ -260,9 +259,11 @@ public class SmartSuggestionsHelperTest { assertThat(firstEvent.getEntityTypes()) .asList() .containsExactly(ConversationAction.TYPE_TEXT_REPLY, ConversationAction.TYPE_OPEN_URL); + assertThat(firstEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME); TextClassifierEvent secondEvent = textClassifierEvents.get(1); assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION); assertThat(secondEvent.getEntityTypes()[0]).isEqualTo(ConversationAction.TYPE_TEXT_REPLY); + assertThat(secondEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME); } @Test |