summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChang Li <licha@google.com>2022-09-12 20:03:37 +0000
committerPresubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com>2022-09-12 20:03:37 +0000
commite2dcf61bd2d7840c76017b01ec8b0447107b6c90 (patch)
treead92f308c0f448078a0c0657d5411958cdecf160
parent32414a370da93be06d9bb77aaf70e290f8ca6eb2 (diff)
parent1075b1e4e39ab4af90deb3758e5631943c07d47e (diff)
downloadlibtextclassifier-e2dcf61bd2d7840c76017b01ec8b0447107b6c90.tar.gz
[automerge] Export external/libtextclassifier to AOSP. 2p: 1075b1e4e3
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/libtextclassifier/+/19932204 Bug: 187927611 Change-Id: I314d2899800a28315764686529d6b1d7e03556c6
-rw-r--r--native/actions/actions-suggestions.cc34
-rw-r--r--native/actions/actions-suggestions.h17
-rw-r--r--native/actions/actions-suggestions_test.cc21
-rw-r--r--native/actions/actions_model.fbs13
-rw-r--r--native/actions/test_data/actions_suggestions_grammar_test.modelbin145632 -> 145616 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.modelbin3384992 -> 3385008 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_9heads.modelbin3866944 -> 3866880 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.modelbin3808128 -> 3808080 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.modelbin0 -> 10192720 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.modelbin3848464 -> 3848720 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.modelbin4667328 -> 4667088 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.modelbin5035360 -> 5035952 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.sensitive_tflite.modelbin7106352 -> 7106288 bytes
-rw-r--r--notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java25
-rw-r--r--notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java5
15 files changed, 93 insertions, 22 deletions
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 9f9a8d4..eeeb508 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -21,6 +21,8 @@
#include <vector>
#include "utils/base/statusor.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/random.h"
#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
@@ -42,6 +44,7 @@
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/random/distributions.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
@@ -813,6 +816,8 @@ void ActionsSuggestions::PopulateTextReplies(
const tflite::Interpreter* interpreter, int suggestion_index,
int score_index, const std::string& type, float priority_score,
const absl::flat_hash_set<std::string>& blocklist,
+ const absl::flat_hash_map<std::string, std::vector<std::string>>&
+ concept_mappings,
ActionsSuggestionsResponse* response) const {
const std::vector<tflite::StringRef> replies =
model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
@@ -831,6 +836,12 @@ void ActionsSuggestions::PopulateTextReplies(
if (blocklist.contains(response_text)) {
continue;
}
+ if (concept_mappings.contains(response_text)) {
+ const int candidates_size = concept_mappings.at(response_text).size();
+ const int candidate_index = absl::Uniform<int>(
+ absl::IntervalOpenOpen, bit_gen_, 0, candidates_size);
+ response_text = concept_mappings.at(response_text)[candidate_index];
+ }
response->actions.push_back({response_text, type, score, priority_score});
}
@@ -918,11 +929,11 @@ bool ActionsSuggestions::ReadModelOutput(
if (!response->output_filtered_min_triggering_score &&
model_->tflite_model_spec()->output_replies() >= 0) {
absl::flat_hash_set<std::string> empty_blocklist;
- PopulateTextReplies(interpreter,
- model_->tflite_model_spec()->output_replies(),
- model_->tflite_model_spec()->output_replies_scores(),
- model_->smart_reply_action_type()->str(),
- /* priority_score */ 0.0, empty_blocklist, response);
+ PopulateTextReplies(
+ interpreter, model_->tflite_model_spec()->output_replies(),
+ model_->tflite_model_spec()->output_replies_scores(),
+ model_->smart_reply_action_type()->str(),
+ /* priority_score */ 0.0, empty_blocklist, {}, response);
}
// Read actions suggestions.
@@ -961,6 +972,8 @@ bool ActionsSuggestions::ReadModelOutput(
const int suggestions_scores_index =
metadata->output_suggestions_scores();
absl::flat_hash_set<std::string> response_text_blocklist;
+ absl::flat_hash_map<std::string, std::vector<std::string>>
+ concept_mappings;
switch (metadata->prediction_type()) {
case PredictionType_NEXT_MESSAGE_PREDICTION:
if (!task_spec || task_spec->type()->size() == 0) {
@@ -973,13 +986,22 @@ bool ActionsSuggestions::ReadModelOutput(
response_text_blocklist.insert(val->str());
}
}
+ if (task_spec->concept_mappings()) {
+ for (const auto& concept : *task_spec->concept_mappings()) {
+ std::vector<std::string> candidates;
+ for (const auto& candidate : *concept->candidates()) {
+ candidates.push_back(candidate->str());
+ }
+ concept_mappings[concept->concept_name()->str()] = candidates;
+ }
+ }
}
PopulateTextReplies(
interpreter, suggestions_index, suggestions_scores_index,
task_spec ? task_spec->type()->str()
: model_->smart_reply_action_type()->str(),
task_spec ? task_spec->priority_score() : 0.0,
- response_text_blocklist, response);
+ response_text_blocklist, concept_mappings, response);
break;
case PredictionType_INTENT_TRIGGERING:
PopulateIntentTriggering(interpreter, suggestions_index,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 87f55fb..c3d58e4 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -43,7 +43,9 @@
#include "utils/utf8/unilib.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/random/random.h"
namespace libtextclassifier3 {
@@ -175,11 +177,13 @@ class ActionsSuggestions {
void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
ActionSuggestion* suggestion) const;
- void PopulateTextReplies(const tflite::Interpreter* interpreter,
- int suggestion_index, int score_index,
- const std::string& type, float priority_score,
- const absl::flat_hash_set<std::string>& blocklist,
- ActionsSuggestionsResponse* response) const;
+ void PopulateTextReplies(
+ const tflite::Interpreter* interpreter, int suggestion_index,
+ int score_index, const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
+ const absl::flat_hash_map<std::string, std::vector<std::string>>&
+ concept_mappings,
+ ActionsSuggestionsResponse* response) const;
void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
int suggestion_index, int score_index,
@@ -273,6 +277,9 @@ class ActionsSuggestions {
// Conversation intent detection model for additional actions.
std::unique_ptr<const ConversationIntentDetection>
conversation_intent_detection_;
+
+ // Used for randomly selecting candidates.
+ mutable absl::BitGen bit_gen_;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index b51ebc7..65f9796 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -61,6 +61,8 @@ constexpr char kMultiTaskSrP13nModelFileName[] =
"actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
"actions_suggestions_test.multi_task_sr_emoji.model";
+constexpr char kMultiTaskSrEmojiConceptModelFileName[] =
+ "actions_suggestions_test.multi_task_sr_emoji_concept.model";
constexpr char kSensitiveTFliteModelFileName[] =
"actions_suggestions_test.sensitive_tflite.model";
constexpr char kLiveRelayTFLiteModelFileName[] =
@@ -1835,6 +1837,25 @@ TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
EXPECT_EQ(response.actions[2].type, "text_reply");
}
+TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kMultiTaskSrEmojiConceptModelFileName);
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "i am tired",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ std::vector<std::string> sigh_emojis = {"😔", "😞"};
+
+ EXPECT_TRUE(std::find(sigh_emojis.begin(), sigh_emojis.end(),
+ response.actions[0].response_text) !=
+ sigh_emojis.end());
+ EXPECT_EQ(response.actions[0].type, "emoji_reply");
+}
+
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kLiveRelayTFLiteModelFileName);
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 0d8c7ad..70f9104 100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -312,6 +312,15 @@ table TriggeringPreconditions {
min_reply_score_threshold:float = 0;
}
+// This proto handles model outputs that are concepts, such as emoji concept
+// suggestion models. Each concept maps to a list of candidates. One of
+// the candidates is chosen randomly as the final suggestion.
+namespace libtextclassifier3;
+table ActionConceptToSuggestion {
+ concept_name:string (shared);
+ candidates:[string];
+}
+
namespace libtextclassifier3;
table ActionSuggestionSpec {
// Type of the action suggestion.
@@ -331,6 +340,10 @@ table ActionSuggestionSpec {
entity_data:ActionsEntityData;
response_text_blocklist:[string];
+
+ // If provided, map the response as concept to one of the corresponding
+ // candidates.
+ concept_mappings:[ActionConceptToSuggestion];
}
// Options to specify triggering behaviour per action class.
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index a44bfe6..6d7bdb0 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index d262953..88f62eb 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index 96fd5ef..40a2409 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index b77d2b5..effb2cb 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model
new file mode 100644
index 0000000..18333d6
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index ad9b684..e41ab39 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index 7fa095a..5314b43 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index 33bb389..a633742 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index 11b7524..6685d26 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
index 9429b29..28f947b 100644
--- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -96,16 +96,11 @@ public class SmartSuggestionsHelper {
oldSession.destroy();
}
};
- private final TextClassificationContext textClassificationContext;
public SmartSuggestionsHelper(Context context, SmartSuggestionsConfig config) {
this.context = context;
textClassificationManager = this.context.getSystemService(TextClassificationManager.class);
this.config = config;
- this.textClassificationContext =
- new TextClassificationContext.Builder(
- context.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION)
- .build();
}
/**
@@ -170,7 +165,10 @@ public class SmartSuggestionsHelper {
} else {
SmartSuggestionsLogSession session =
new SmartSuggestionsLogSession(
- resultId, repliesScore, textClassifier, textClassificationContext);
+ resultId,
+ repliesScore,
+ textClassifier,
+ getTextClassificationContext(statusBarNotification));
session.onSuggestionsGenerated(conversationActions);
// Store the session if we expect more logging from it, destroy it otherwise.
@@ -302,7 +300,11 @@ public class SmartSuggestionsHelper {
.setTypeConfig(typeConfigBuilder.build())
.build();
- TextClassifier textClassifier = createTextClassificationSession();
+ TextClassifier textClassifier =
+ textClassificationManager.createTextClassificationSession(
+ getTextClassificationContext(statusBarNotification));
+ onTextClassificationSessionCreated();
+
return new SuggestConversationActionsResult(
Optional.of(textClassifier), textClassifier.suggestConversationActions(request));
}
@@ -477,8 +479,13 @@ public class SmartSuggestionsHelper {
}
@VisibleForTesting
- TextClassifier createTextClassificationSession() {
- return textClassificationManager.createTextClassificationSession(textClassificationContext);
+ void onTextClassificationSessionCreated() {}
+
+ private static TextClassificationContext getTextClassificationContext(
+ StatusBarNotification statusBarNotification) {
+ return new TextClassificationContext.Builder(
+ statusBarNotification.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION)
+ .build();
}
private static boolean arePersonsEqual(Person left, Person right) {
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
index 84cf4fb..9354819 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
@@ -86,9 +86,8 @@ public class SmartSuggestionsHelperTest {
}
@Override
- TextClassifier createTextClassificationSession() {
+ void onTextClassificationSessionCreated() {
numOfSessionsCreated += 1;
- return super.createTextClassificationSession();
}
int getNumOfSessionsCreated() {
@@ -260,9 +259,11 @@ public class SmartSuggestionsHelperTest {
assertThat(firstEvent.getEntityTypes())
.asList()
.containsExactly(ConversationAction.TYPE_TEXT_REPLY, ConversationAction.TYPE_OPEN_URL);
+ assertThat(firstEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME);
TextClassifierEvent secondEvent = textClassifierEvents.get(1);
assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
assertThat(secondEvent.getEntityTypes()[0]).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
+ assertThat(secondEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME);
}
@Test