summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-03-15 18:59:59 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-03-15 18:59:59 +0000
commit3d5da76ab70029db2810b2d7b7611bafdc258c50 (patch)
tree3a3d2ccd661e94baac725672eb79068f9778a135
parentf106c46253e7ec42e5d39ab1b3fa3ada443917b2 (diff)
parent8ebbedca8443b38941a7ddadc8245fcc83c6f866 (diff)
downloadlibtextclassifier-android12-mainline-sdkext-release.tar.gz
Snap for 8303596 from 8ebbedca8443b38941a7ddadc8245fcc83c6f866 to mainline-sdkext-releaseandroid-mainline-12.0.0_r109aml_sdk_311710000android12-mainline-sdkext-release
Change-Id: I5ae59fe3453aa2ba4fe57f826c1f51bae4092d3f
-rw-r--r--java/src/com/android/textclassifier/ExtrasUtils.java4
-rw-r--r--java/src/com/android/textclassifier/ModelFileManagerImpl.java13
-rw-r--r--java/src/com/android/textclassifier/common/logging/ResultIdUtils.java4
-rw-r--r--java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java4
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloadManager.java179
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java8
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java10
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloaderService.java2
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java2
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java44
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java7
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java29
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java10
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java171
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java7
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java6
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java61
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java185
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java8
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java108
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java58
-rw-r--r--native/actions/actions-entity-data.bfbsbin880 -> 888 bytes
-rw-r--r--native/actions/actions-entity-data.fbs2
-rw-r--r--native/actions/actions-suggestions.cc29
-rw-r--r--native/actions/actions-suggestions.h4
-rw-r--r--native/actions/actions-suggestions_test.cc29
-rw-r--r--native/actions/actions_model.fbs16
-rw-r--r--native/actions/ranker.cc74
-rw-r--r--native/actions/ranker_test.cc96
-rw-r--r--native/actions/test_data/actions_suggestions_grammar_test.modelbin145160 -> 145176 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.modelbin3387328 -> 3387360 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_9heads.modelbin3874528 -> 3874704 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.modelbin3808528 -> 3812304 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.modelbin3853520 -> 3853520 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.modelbin4671808 -> 4671840 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.modelbin5045280 -> 5045408 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.sensitive_tflite.modelbin7111552 -> 7111552 bytes
-rw-r--r--native/annotator/annotator.cc79
-rw-r--r--native/annotator/datetime/datetime-grounder.cc9
-rw-r--r--native/annotator/datetime/extractor.cc11
-rw-r--r--native/annotator/datetime/regex-parser.cc31
-rw-r--r--native/annotator/translate/translate.cc11
-rw-r--r--native/lang_id/common/embedding-network.cc2
-rw-r--r--native/lang_id/common/fel/feature-extractor.cc1
-rw-r--r--native/lang_id/common/fel/workspace.cc1
-rw-r--r--native/lang_id/common/fel/workspace.h1
-rw-r--r--native/lang_id/common/file/mmap.cc2
-rw-r--r--native/lang_id/common/lite_strings/str-split.cc2
-rw-r--r--native/lang_id/common/math/softmax.cc1
-rw-r--r--native/lang_id/fb_model/lang-id-from-fb.cc2
-rw-r--r--native/lang_id/fb_model/model-provider-from-fb.cc2
-rw-r--r--native/lang_id/lang-id.cc1
-rw-r--r--native/utils/codepoint-range.cc9
-rw-r--r--native/utils/grammar/parsing/parser.cc17
-rw-r--r--native/utils/grammar/utils/ir.cc35
-rw-r--r--native/utils/grammar/utils/locale-shard-map.cc4
-rw-r--r--native/utils/testing/test_data_generator.h13
-rw-r--r--native/utils/tflite-model-executor.cc14
-rw-r--r--native/utils/tflite/encoder_common.cc11
-rw-r--r--native/utils/tflite/text_encoder3s.cc243
-rw-r--r--native/utils/tflite/text_encoder3s.h35
-rw-r--r--native/utils/tokenfree/byte_encoder.cc42
-rw-r--r--native/utils/tokenfree/byte_encoder.h37
-rw-r--r--native/utils/tokenfree/byte_encoder_test.cc51
-rw-r--r--native/utils/tokenizer.cc11
-rw-r--r--notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java8
66 files changed, 1273 insertions, 583 deletions
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java
index fd64581..bde3898 100644
--- a/java/src/com/android/textclassifier/ExtrasUtils.java
+++ b/java/src/com/android/textclassifier/ExtrasUtils.java
@@ -87,7 +87,9 @@ public final class ExtrasUtils {
return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
}
- /** @see #getTopLanguage(Intent) */
+ /**
+ * @see #getTopLanguage(Intent)
+ */
static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
final int maxSize = Math.min(3, languageScores.getEntities().size());
final String[] languages =
diff --git a/java/src/com/android/textclassifier/ModelFileManagerImpl.java b/java/src/com/android/textclassifier/ModelFileManagerImpl.java
index 45426d0..e3b646f 100644
--- a/java/src/com/android/textclassifier/ModelFileManagerImpl.java
+++ b/java/src/com/android/textclassifier/ModelFileManagerImpl.java
@@ -390,7 +390,18 @@ final class ModelFileManagerImpl implements ModelFileManager {
localePreferences.get(0),
targetLocale));
}
- return findBestModelFile(modelType, targetLocale);
+ ModelFile modelFile = findBestModelFile(modelType, targetLocale);
+ TcLog.d(
+ TAG,
+ String.format(
+ Locale.US,
+ "findBestModelFile: best model: %s; localePreferences: %s; detectedLocales: %s;"
+ + " targetLocale: %s",
+ modelFile,
+ localePreferences,
+ detectedLocales,
+ targetLocale));
+ return modelFile;
}
/**
diff --git a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
index dae0442..67e300d 100644
--- a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
+++ b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
@@ -66,8 +66,8 @@ public final class ResultIdUtils {
}
/** Returns if the result id was generated from the default text classifier. */
- public static boolean isFromDefaultTextClassifier(String resultId) {
- return resultId.startsWith(CLASSIFIER_ID + '|');
+ public static boolean isFromDefaultTextClassifier(@Nullable String resultId) {
+ return resultId != null && resultId.startsWith(CLASSIFIER_ID + '|');
}
/** Returns all the model names encoded in the signature. */
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
index 1ae79ce..9bdfb5e 100644
--- a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
@@ -195,7 +195,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager
@Override
public void onDownloadCompleted(
ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) {
- TcLog.v(TAG, "Start to clean up models and update model lookup cache...");
+ TcLog.d(TAG, "Start to clean up models and update model lookup cache...");
// Step 1: Clean up ManifestEnrollment table
List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
@@ -286,7 +286,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager
// Clear the cache table and rebuild the cache based on ModelView table
private void updateCache() {
synchronized (cacheLock) {
- TcLog.v(TAG, "Updating model lookup cache...");
+ TcLog.d(TAG, "Updating model lookup cache...");
for (String modelType : ModelType.values()) {
modelLookupCache.get(modelType).clear();
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
index b125f13..af33e21 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
@@ -44,6 +44,7 @@ import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Enums;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
import com.google.common.hash.Hashing;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
@@ -54,6 +55,7 @@ import java.time.Instant;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
+import java.util.concurrent.Callable;
import javax.annotation.Nullable;
/** Manager to listen to config update and download latest models. */
@@ -64,6 +66,7 @@ public final class ModelDownloadManager {
private final Context appContext;
private final Class<? extends ListenableWorker> modelDownloadWorkerClass;
+ private final Callable<WorkManager> workManagerSupplier;
private final DownloadedModelManager downloadedModelManager;
private final TextClassifierSettings settings;
private final ListeningExecutorService executorService;
@@ -84,6 +87,7 @@ public final class ModelDownloadManager {
this(
appContext,
ModelDownloadWorker.class,
+ () -> WorkManager.getInstance(appContext),
DownloadedModelManagerImpl.getInstance(appContext),
settings,
executorService);
@@ -93,11 +97,13 @@ public final class ModelDownloadManager {
public ModelDownloadManager(
Context appContext,
Class<? extends ListenableWorker> modelDownloadWorkerClass,
+ Callable<WorkManager> workManagerSupplier,
DownloadedModelManager downloadedModelManager,
TextClassifierSettings settings,
ListeningExecutorService executorService) {
this.appContext = Preconditions.checkNotNull(appContext);
this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass);
+ this.workManagerSupplier = Preconditions.checkNotNull(workManagerSupplier);
this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager);
this.settings = Preconditions.checkNotNull(settings);
this.executorService = Preconditions.checkNotNull(executorService);
@@ -121,22 +127,31 @@ public final class ModelDownloadManager {
/** Returns the downlaoded models for the given modelType. */
@Nullable
public List<File> listDownloadedModels(@ModelTypeDef String modelType) {
- return downloadedModelManager.listModels(modelType);
+ try {
+ return downloadedModelManager.listModels(modelType);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to list downloaded models", t);
+ return ImmutableList.of();
+ }
}
/** Notifies the model downlaoder that the text classifier service is created. */
public void onTextClassifierServiceCreated() {
- DeviceConfig.addOnPropertiesChangedListener(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
- appContext.registerReceiver(
- localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
- TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered.");
- if (!settings.isModelDownloadManagerEnabled()) {
- return;
+ try {
+ DeviceConfig.addOnPropertiesChangedListener(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
+ appContext.registerReceiver(
+ localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
+ TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered.");
+ if (!settings.isModelDownloadManagerEnabled()) {
+ return;
+ }
+ maybeOverrideLocaleListForTesting();
+ TcLog.d(TAG, "Try to schedule model download work because TextClassifierService started.");
+ scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onTextClassifierServiceCreated", t);
}
- maybeOverrideLocaleListForTesting();
- TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started.");
- scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED);
}
// TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -146,8 +161,12 @@ public final class ModelDownloadManager {
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- TcLog.v(TAG, "Try to schedule model download work because of system locale changes.");
- scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED);
+ TcLog.d(TAG, "Try to schedule model download work because of system locale changes.");
+ try {
+ scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onLocaleChanged", t);
+ }
}
// TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -157,16 +176,24 @@ public final class ModelDownloadManager {
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- maybeOverrideLocaleListForTesting();
- TcLog.v(TAG, "Try to schedule model download work because of device config changes.");
- scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED);
+ TcLog.d(TAG, "Try to schedule model download work because of device config changes.");
+ try {
+ maybeOverrideLocaleListForTesting();
+ scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onTextClassifierDeviceConfigChanged", t);
+ }
}
/** Clean up internal states on destroying. */
public void destroy() {
- DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
- appContext.unregisterReceiver(localeChangedReceiver);
- TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager");
+ try {
+ DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
+ appContext.unregisterReceiver(localeChangedReceiver);
+ TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager");
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to destroy ModelDownloadManager", t);
+ }
}
/**
@@ -178,10 +205,14 @@ public final class ModelDownloadManager {
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- printWriter.println("ModelDownloadManager:");
- printWriter.increaseIndent();
- downloadedModelManager.dump(printWriter);
- printWriter.decreaseIndent();
+ try {
+ printWriter.println("ModelDownloadManager:");
+ printWriter.increaseIndent();
+ downloadedModelManager.dump(printWriter);
+ printWriter.decreaseIndent();
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to dump ModelDownloadManager", t);
+ }
}
/**
@@ -193,54 +224,62 @@ public final class ModelDownloadManager {
private void scheduleDownloadWork(int reasonToSchedule) {
long workId =
Hashing.farmHashFingerprint64().hashUnencodedChars(UUID.randomUUID().toString()).asLong();
- NetworkType networkType =
- Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
- .or(NetworkType.UNMETERED);
- OneTimeWorkRequest downloadRequest =
- new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
- .setConstraints(
- new Constraints.Builder()
- .setRequiredNetworkType(networkType)
- .setRequiresBatteryNotLow(true)
- .setRequiresStorageNotLow(true)
- .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle())
- .setRequiresCharging(settings.getManifestDownloadRequiresCharging())
- .build())
- .setBackoffCriteria(
- BackoffPolicy.EXPONENTIAL,
- settings.getModelDownloadBackoffDelayInMillis(),
- MILLISECONDS)
- .setInputData(
- new Data.Builder()
- .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId)
- .putLong(
- ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP,
- Instant.now().toEpochMilli())
- .build())
- .build();
- ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture =
- WorkManager.getInstance(appContext)
- .enqueueUniqueWork(
- UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest)
- .getResult();
- Futures.addCallback(
- enqueueResultFuture,
- new FutureCallback<Operation.State.SUCCESS>() {
- @Override
- public void onSuccess(Operation.State.SUCCESS unused) {
- TcLog.v(TAG, "Download work scheduled.");
- TextClassifierDownloadLogger.downloadWorkScheduled(
- workId, reasonToSchedule, /* failedToSchedule= */ false);
- }
+ try {
+ NetworkType networkType =
+ Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
+ .or(NetworkType.UNMETERED);
+ OneTimeWorkRequest downloadRequest =
+ new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
+ .setConstraints(
+ new Constraints.Builder()
+ .setRequiredNetworkType(networkType)
+ .setRequiresBatteryNotLow(true)
+ .setRequiresStorageNotLow(true)
+ .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle())
+ .setRequiresCharging(settings.getManifestDownloadRequiresCharging())
+ .build())
+ .setBackoffCriteria(
+ BackoffPolicy.EXPONENTIAL,
+ settings.getModelDownloadBackoffDelayInMillis(),
+ MILLISECONDS)
+ .setInputData(
+ new Data.Builder()
+ .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId)
+ .putLong(
+ ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP,
+ Instant.now().toEpochMilli())
+ .build())
+ .build();
+ ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture =
+ workManagerSupplier
+ .call()
+ .enqueueUniqueWork(
+ UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest)
+ .getResult();
+ Futures.addCallback(
+ enqueueResultFuture,
+ new FutureCallback<Operation.State.SUCCESS>() {
+ @Override
+ public void onSuccess(Operation.State.SUCCESS unused) {
+ TcLog.d(TAG, "Download work scheduled.");
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ false);
+ }
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "Failed to schedule download work: ", t);
- TextClassifierDownloadLogger.downloadWorkScheduled(
- workId, reasonToSchedule, /* failedToSchedule= */ true);
- }
- },
- executorService);
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "Failed to schedule download work: ", t);
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ true);
+ }
+ },
+ executorService);
+ } catch (Throwable t) {
+ // TODO(licha): this is just for temporary fix. Refactor the try-catch in the future.
+ TcLog.e(TAG, "Failed to schedule download work: ", t);
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ true);
+ }
}
private void maybeOverrideLocaleListForTesting() {
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
index 6e04e16..3db0815 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
@@ -113,6 +113,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
@Override
public final ListenableFuture<ListenableWorker.Result> startWork() {
+ TcLog.d(TAG, "Start download work...");
workStartedTimeMillis = getCurrentTimeMillis();
// Notice: startWork() is invoked on the main thread
if (!settings.isModelDownloadManagerEnabled()) {
@@ -121,7 +122,6 @@ public final class ModelDownloadWorker extends ListenableWorker {
TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED);
return Futures.immediateFuture(ListenableWorker.Result.failure());
}
- TcLog.v(TAG, "Start download work...");
if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
TcLog.d(TAG, "Max attempt reached. Abort download work.");
logDownloadWorkCompleted(
@@ -134,7 +134,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
downloadResult -> {
Preconditions.checkNotNull(manifestsToDownload);
downloadedModelManager.onDownloadCompleted(manifestsToDownload);
- TcLog.v(TAG, "Download work completed: " + downloadResult);
+ TcLog.d(TAG, "Download work completed: " + downloadResult);
if (downloadResult.failureCount() == 0) {
logDownloadWorkCompleted(
downloadResult.successCount() > 0
@@ -239,7 +239,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
return Futures.whenAllComplete(downloadResultFutures)
.call(
() -> {
- TcLog.v(TAG, "All Download Tasks Completed");
+ TcLog.d(TAG, "All Download Tasks Completed");
int successCount = 0;
int failureCount = 0;
for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
@@ -333,7 +333,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl);
if (downloadedManifest != null
&& downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) {
- TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl);
+ TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl);
return Futures.immediateVoidFuture();
}
if (pendingDownloads.containsKey(manifestUrl)) {
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
index 2244e9a..0b76f22 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
@@ -99,7 +99,7 @@ final class ModelDownloaderImpl implements ModelDownloader {
new FutureCallback<File>() {
@Override
public void onSuccess(File pendingModelFile) {
- TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
+ TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
}
@Override
@@ -170,11 +170,11 @@ final class ModelDownloaderImpl implements ModelDownloader {
} catch (IOException e) {
throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e);
}
- TcLog.v(TAG, "Pending model file passed validation.");
+ TcLog.d(TAG, "Pending model file passed validation.");
}
private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) {
- TcLog.v(TAG, "Starting a new connection to ModelDownloaderService");
+ TcLog.d(TAG, "Starting a new connection to ModelDownloaderService");
return CallbackToFutureAdapter.getFuture(
completer -> {
conn.attachCompleter(completer);
@@ -197,7 +197,7 @@ final class ModelDownloaderImpl implements ModelDownloader {
// restult future will hang there until time out. (WorkManager forces a 10-min running time.)
private static ListenableFuture<Long> scheduleDownload(
IModelDownloaderService service, URI uri, File targetFile) {
- TcLog.v(TAG, "Scheduling a new download task with ModelDownloaderService");
+ TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService");
return CallbackToFutureAdapter.getFuture(
completer -> {
service.download(
@@ -236,7 +236,7 @@ final class ModelDownloaderImpl implements ModelDownloader {
@Override
public void onServiceConnected(ComponentName componentName, IBinder iBinder) {
- TcLog.v(TAG, "DownloaderService connected");
+ TcLog.d(TAG, "DownloaderService connected");
completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder)));
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
index e4ebbfa..6d7e47e 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
@@ -39,7 +39,7 @@ public final class ModelDownloaderService extends Service {
@Override
public IBinder onBind(Intent intent) {
- TcLog.v(TAG, "Binding to ModelDownloadService");
+ TcLog.d(TAG, "Binding to ModelDownloadService");
return iBinder;
}
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
index 439588b..47e6f19 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
@@ -91,7 +91,7 @@ final class ModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
@Override
public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) {
- TcLog.v(TAG, "Download request received: " + uri);
+ TcLog.d(TAG, "Download request received: " + uri);
try {
File targetFile = new File(targetFilePath);
File tempMetadataFile = getMetadataFile(targetFile);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 0e3842c..ddab8bd 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -17,7 +17,10 @@
package com.android.textclassifier;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import android.content.Context;
import android.os.CancellationSignal;
@@ -39,6 +42,7 @@ import com.android.os.AtomsProto.Atom;
import com.android.os.AtomsProto.TextClassifierApiUsageReported;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
+import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.statsd.StatsdTestUtils;
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
@@ -47,21 +51,27 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
+import java.io.IOException;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class DefaultTextClassifierServiceTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
@@ -76,14 +86,21 @@ public class DefaultTextClassifierServiceTest {
@Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
@Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
@Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
+ @Mock private ModelFileManager testModelFileManager;
@Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
-
- testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
+ public void setup() throws IOException {
+ testInjector =
+ new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager);
defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
defaultTextClassifierService.onCreate();
+
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+ when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+ .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+ .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
}
@Before
@@ -207,11 +224,8 @@ public class DefaultTextClassifierServiceTest {
@Test
public void missingModelFile_onFailureShouldBeCalled() throws Exception {
- testInjector.setModelFileManager(
- new ModelFileManagerImpl(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(),
- testInjector.createTextClassifierSettings()));
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(null);
defaultTextClassifierService.onCreate();
TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
@@ -247,12 +261,9 @@ public class DefaultTextClassifierServiceTest {
private final Context context;
private ModelFileManager modelFileManager;
- private TestInjector(Context context) {
+ private TestInjector(Context context, ModelFileManager modelFileManager) {
this.context = Preconditions.checkNotNull(context);
- }
-
- private void setModelFileManager(ModelFileManager modelFileManager) {
- this.modelFileManager = modelFileManager;
+ this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
}
@Override
@@ -263,9 +274,6 @@ public class DefaultTextClassifierServiceTest {
@Override
public ModelFileManager createModelFileManager(
TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
- if (modelFileManager == null) {
- return TestDataUtils.createModelFileManagerForTesting(context);
- }
return modelFileManager;
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
index 5297640..0e40515 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
@@ -25,6 +25,7 @@ import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
+import androidx.work.WorkManager;
import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister;
@@ -53,7 +54,8 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
@@ -67,6 +69,7 @@ public final class ModelFileManagerImplTest {
@Mock private DownloadedModelManager downloadedModelManager;
@Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
private File rootTestDir;
private ModelFileManagerImpl modelFileManager;
@@ -75,7 +78,6 @@ public final class ModelFileManagerImplTest {
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
deviceConfig = new TestingDeviceConfig();
rootTestDir =
new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
@@ -86,6 +88,7 @@ public final class ModelFileManagerImplTest {
new ModelDownloadManager(
context,
ModelDownloadWorker.class,
+ () -> WorkManager.getInstance(context),
downloadedModelManager,
settings,
MoreExecutors.newDirectExecutorService());
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index bac4fa1..a19e3ff 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -16,12 +16,10 @@
package com.android.textclassifier;
-import android.content.Context;
-import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
+import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.ModelType;
-import com.android.textclassifier.common.TextClassifierSettings;
-import com.google.common.collect.ImmutableList;
import java.io.File;
+import java.io.IOException;
/** Utils to access test data files. */
public final class TestDataUtils {
@@ -30,7 +28,7 @@ public final class TestDataUtils {
private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
/** Returns the root folder that contains the test data. */
- public static File getTestDataFolder() {
+ private static File getTestDataFolder() {
return new File("/data/local/tmp/TextClassifierServiceTest/");
}
@@ -38,24 +36,25 @@ public final class TestDataUtils {
return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH);
}
+ public static ModelFile getTestAnnotatorModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(getTestAnnotatorModelFile(), ModelType.ANNOTATOR);
+ }
+
public static File getTestActionsModelFile() {
return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH);
}
+ public static ModelFile getTestActionsModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(
+ getTestActionsModelFile(), ModelType.ACTIONS_SUGGESTIONS);
+ }
+
public static File getLangIdModelFile() {
return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH);
}
- public static ModelFileManager createModelFileManagerForTesting(Context context) {
- return new ModelFileManagerImpl(
- context,
- ImmutableList.of(
- new RegularFileFullMatchLister(
- ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
- new RegularFileFullMatchLister(
- ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
- new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)),
- new TextClassifierSettings());
+ public static ModelFile getLangIdModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(getLangIdModelFile(), ModelType.LANG_ID);
}
private TestDataUtils() {}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
index 42177e6..e7bf90c 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -56,6 +56,10 @@ public class TextClassifierApiTest {
@Before
public void setup() {
+ extServicesTextClassifierRule.enableVerboseLogging();
+ // Verbose logging only takes effect after restarting ExtServices
+ extServicesTextClassifierRule.forceStopExtServices();
+
textClassifier = extServicesTextClassifierRule.getTextClassifier();
}
@@ -81,8 +85,8 @@ public class TextClassifierApiTest {
@Test
public void classifyText() {
- String text = "Contact me at droid@android.com";
- String classifiedText = "droid@android.com";
+ String text = "Contact me at http://www.android.com";
+ String classifiedText = "http://www.android.com";
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
@@ -90,7 +94,7 @@ public class TextClassifierApiTest {
TextClassification classification = textClassifier.classifyText(request);
assertThat(classification.getEntityCount()).isGreaterThan(0);
- assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+ assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
assertThat(classification.getText()).isEqualTo(classifiedText);
assertThat(classification.getActions()).isNotEmpty();
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index fb1aea8..c20ec8a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -22,6 +22,9 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.expectThrows;
@@ -74,30 +77,34 @@ public class TextClassifierImplTest {
private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
private static final String NO_TYPE = null;
- @Mock private ModelFileManagerImpl.ModelFileLister mockModelFileLister;
+ @Mock private ModelFileManager modelFileManager;
- private TextClassifierSettings settings;
private Context context;
private TestingDeviceConfig deviceConfig;
- private TextClassifierImpl classifier;
-
- private final ModelFileManager modelFileManager =
- TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
+ private TextClassifierSettings settings;
private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
+ private TextClassifierImpl classifier;
@Before
- public void setup() {
+ public void setup() throws IOException {
MockitoAnnotations.initMocks(this);
- deviceConfig = new TestingDeviceConfig();
- Context context =
+ this.context =
new FakeContextBuilder()
.setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
.setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
.build();
- this.context = context;
- settings = new TextClassifierSettings(deviceConfig);
- // TODO(veronikanikina): consider using a testing constructor here.
- classifier = new TextClassifierImpl(context, settings, modelFileManager);
+ this.deviceConfig = new TestingDeviceConfig();
+ this.settings = new TextClassifierSettings(deviceConfig);
+ this.annotatorModelCache = new LruCache<>(2);
+ this.classifier =
+ new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
+
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+ when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+ .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+ when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+ .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
}
@Test
@@ -110,9 +117,7 @@ public class TextClassifierImplTest {
int smartStartIndex = text.indexOf(suggested);
int smartEndIndex = smartStartIndex + suggested.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(
@@ -120,6 +125,24 @@ public class TextClassifierImplTest {
}
@Test
+ public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
+ String text = "Contact me at droid@android.com";
+ String selected = "droid";
+ String suggested = "droid@android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ classifier.suggestSelection(null, null, request);
+ verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
+ }
+
+ @Test
public void testSuggestSelection_url() throws IOException {
String text = "Visit http://www.android.com for more information";
String selected = "http";
@@ -129,9 +152,7 @@ public class TextClassifierImplTest {
int smartStartIndex = text.indexOf(suggested);
int smartEndIndex = smartStartIndex + suggested.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
@@ -144,9 +165,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(selected);
int endIndex = startIndex + selected.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
@@ -160,7 +179,6 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(suggested);
TextSelection.Request request =
new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1)
- .setDefaultLocales(LOCALES)
.setIncludeTextClassification(true)
.build();
@@ -178,7 +196,6 @@ public class TextClassifierImplTest {
String text = "Visit http://www.android.com for more information";
TextSelection.Request request =
new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4)
- .setDefaultLocales(LOCALES)
.setIncludeTextClassification(false)
.build();
@@ -194,9 +211,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification =
classifier.classifyText(/* sessionId= */ null, null, request);
@@ -210,9 +225,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -223,9 +236,7 @@ public class TextClassifierImplTest {
public void testClassifyText_address() throws IOException {
String text = "Brandschenkestrasse 110, Zürich, Switzerland";
TextClassification.Request request =
- new TextClassification.Request.Builder(text, 0, text.length())
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, 0, text.length()).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
@@ -238,9 +249,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -254,9 +263,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
@@ -275,9 +282,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
@@ -289,14 +294,12 @@ public class TextClassifierImplTest {
LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String japaneseText = "これは日本語のテキストです";
TextClassification.Request request =
- new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length())
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();
TextClassification classification = classifier.classifyText(null, null, request);
RemoteAction translateAction = classification.getActions().get(0);
assertEquals(1, classification.getActions().size());
- assertEquals("Translate", translateAction.getTitle().toString());
+ assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
@@ -323,18 +326,17 @@ public class TextClassifierImplTest {
@Test
public void testGenerateLinks_exclude() throws IOException {
- String text = "You want apple@banana.com. See you tonight!";
+ String text = "The number is +12122537077. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = ImmutableList.of();
- List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
- not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+ not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
}
@Test
@@ -344,7 +346,6 @@ public class TextClassifierImplTest {
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
@@ -361,7 +362,6 @@ public class TextClassifierImplTest {
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
@@ -573,29 +573,16 @@ public class TextClassifierImplTest {
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
ModelFile annotatorModelB =
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
- String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath();
- ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false);
-
- annotatorModelCache = new LruCache<>(2);
- ModelFileManager modelFileManagerCached =
- new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings);
- TextClassifierImpl textClassifierImpl =
- new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache);
- LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String englishText = "You can reach me on +12122537077.";
String classifiedText = "+12122537077";
TextClassification.Request request =
- new TextClassification.Request.Builder(englishText, 0, englishText.length())
- .setDefaultLocales(LOCALES)
- .build();
-
- when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel));
+ new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
// Check modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classificationA = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classificationA = classifier.classifyText(null, null, request);
assertThat(classificationA.getId()).contains("v701");
assertThat(classificationA.getText()).contains(classifiedText);
@@ -609,9 +596,9 @@ public class TextClassifierImplTest {
});
// Check modelFileB v801
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelB));
- TextClassification classificationB = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelB);
+ TextClassification classificationB = classifier.classifyText(null, null, request);
assertThat(classificationB.getId()).contains("v801");
assertThat(classificationB.getText()).contains(classifiedText);
@@ -625,9 +612,9 @@ public class TextClassifierImplTest {
});
// Reload modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classificationAcached = classifier.classifyText(null, null, request);
assertThat(classificationAcached.getId()).contains("v701");
assertThat(classificationAcached.getText()).contains(classifiedText);
@@ -651,28 +638,16 @@ public class TextClassifierImplTest {
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
ModelFile annotatorModelB =
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
- String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath();
- ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false);
-
- annotatorModelCache = new LruCache<>(settings.getMultiAnnotatorCacheSize());
- ModelFileManager modelFileManagerCached =
- new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings);
- TextClassifierImpl textClassifierImpl =
- new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache);
- LocaleList.setDefault(LocaleList.forLanguageTags("en"));
+
String englishText = "You can reach me on +12122537077.";
String classifiedText = "+12122537077";
TextClassification.Request request =
- new TextClassification.Request.Builder(englishText, 0, englishText.length())
- .setDefaultLocales(LOCALES)
- .build();
-
- when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel));
+ new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
// Check modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classification = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification.getId()).contains("v701");
assertThat(classification.getText()).contains(classifiedText);
@@ -686,9 +661,9 @@ public class TextClassifierImplTest {
});
// Check modelFileB v801
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelB));
- TextClassification classificationB = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelB);
+ TextClassification classificationB = classifier.classifyText(null, null, request);
assertThat(classificationB.getId()).contains("v801");
assertThat(classificationB.getText()).contains(classifiedText);
@@ -702,9 +677,9 @@ public class TextClassifierImplTest {
});
// Reload modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classificationAcached = classifier.classifyText(null, null, request);
assertThat(classificationAcached.getId()).contains("v701");
assertThat(classificationAcached.getText()).contains(classifiedText);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
index 216cd5d..3aab211 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -27,14 +27,18 @@ import com.google.android.textclassifier.NamedVariant;
import com.google.android.textclassifier.RemoteActionTemplate;
import java.util.List;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class TemplateIntentFactoryTest {
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private static final String TITLE_WITHOUT_ENTITY = "Map";
private static final String TITLE_WITH_ENTITY = "Map NW14D1";
private static final String DESCRIPTION = "Check the map";
@@ -71,7 +75,6 @@ public class TemplateIntentFactoryTest {
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
templateIntentFactory = new TemplateIntentFactory();
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
index ffd2ee4..3a8fefc 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -86,8 +86,8 @@ public class StatsdTestUtils {
return ImmutableList.copyOf(
metricsList.stream()
.flatMap(statsLogReport -> statsLogReport.getEventMetrics().getDataList().stream())
- .flatMap(eventMetricData -> backfillAggregatedAtomsinEventMetric(
- eventMetricData).stream())
+ .flatMap(
+ eventMetricData -> backfillAggregatedAtomsinEventMetric(eventMetricData).stream())
.sorted(Comparator.comparing(EventMetricData::getElapsedTimestampNanos))
.map(EventMetricData::getAtom)
.collect(Collectors.toList()));
@@ -136,7 +136,7 @@ public class StatsdTestUtils {
}
private static ImmutableList<EventMetricData> backfillAggregatedAtomsinEventMetric(
- EventMetricData metricData) {
+ EventMetricData metricData) {
if (metricData.hasAtom()) {
return ImmutableList.of(metricData);
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
index c626ed7..9e11c09 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -46,7 +46,8 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@RunWith(AndroidJUnit4.class)
public final class ModelDownloadManagerTest {
@@ -61,14 +62,16 @@ public final class ModelDownloadManagerTest {
public final TextClassifierDownloadLoggerTestRule loggerTestRule =
new TextClassifierDownloadLoggerTestRule();
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private TestingDeviceConfig deviceConfig;
private WorkManager workManager;
private ModelDownloadManager downloadManager;
+ private ModelDownloadManager downloadManagerWithBadWorkManager;
@Mock DownloadedModelManager downloadedModelManager;
@Before
public void setUp() {
- MockitoAnnotations.initMocks(this);
Context context = ApplicationProvider.getApplicationContext();
WorkManagerTestInitHelper.initializeTestWorkManager(context);
@@ -78,6 +81,17 @@ public final class ModelDownloadManagerTest {
new ModelDownloadManager(
context,
ModelDownloadWorker.class,
+ () -> workManager,
+ downloadedModelManager,
+ new TextClassifierSettings(deviceConfig),
+ MoreExecutors.newDirectExecutorService());
+ this.downloadManagerWithBadWorkManager =
+ new ModelDownloadManager(
+ context,
+ ModelDownloadWorker.class,
+ () -> {
+ throw new IllegalStateException("WorkManager may fail!");
+ },
downloadedModelManager,
new TextClassifierSettings(deviceConfig),
MoreExecutors.newDirectExecutorService());
@@ -94,7 +108,20 @@ public final class ModelDownloadManagerTest {
}
@Test
+ public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
+ downloadManagerWithBadWorkManager.onTextClassifierServiceCreated();
+
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
public void onTextClassifierServiceCreated_requestEnqueued() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
downloadManager.onTextClassifierServiceCreated();
WorkInfo workInfo =
@@ -102,21 +129,34 @@ public final class ModelDownloadManagerTest {
DownloaderTestUtils.queryWorkInfos(
workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
}
@Test
public void onTextClassifierServiceCreated_localeListOverridden() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr");
downloadManager.onTextClassifierServiceCreated();
assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh"));
assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
}
@Test
+ public void onLocaleChanged_workManagerCrashed() throws Exception {
+ downloadManagerWithBadWorkManager.onLocaleChanged();
+
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
public void onLocaleChanged_requestEnqueued() throws Exception {
downloadManager.onLocaleChanged();
@@ -129,6 +169,16 @@ public final class ModelDownloadManagerTest {
}
@Test
+ public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception {
+ downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged();
+
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception {
downloadManager.onTextClassifierDeviceConfigChanged();
@@ -186,6 +236,13 @@ public final class ModelDownloadManagerTest {
assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile);
}
+ @Test
+ public void listDownloadedModels_doNotCrashOnError() throws Exception {
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException());
+
+ assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty();
+ }
+
private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception {
TextClassifierDownloadWorkScheduled atom =
Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
index 9f555fc..e261158 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
@@ -18,16 +18,10 @@ package com.android.textclassifier.downloader;
import static com.google.common.truth.Truth.assertThat;
-import android.app.Instrumentation;
-import android.app.UiAutomation;
import android.util.Log;
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassification.Request;
-import android.view.textclassifier.TextClassifier;
-import androidx.test.platform.app.InstrumentationRegistry;
import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
-import com.android.textclassifier.testing.TestingLocaleListOverrideRule;
-import java.io.IOException;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@@ -48,170 +42,133 @@ public class ModelDownloaderIntegrationTest {
private static final String V804_EN_TAG = "en_v804";
private static final String V804_RU_TAG = "ru_v804";
private static final String FACTORY_MODEL_TAG = "*";
-
- @Rule
- public final TestingLocaleListOverrideRule testingLocaleListOverrideRule =
- new TestingLocaleListOverrideRule();
+ private static final int ASSERT_MAX_ATTEMPTS = 20;
+ private static final int ASSERT_SLEEP_BEFORE_RETRY_MS = 1000;
@Rule
public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
new ExtServicesTextClassifierRule();
- private TextClassifier textClassifier;
-
@Before
public void setup() throws Exception {
- // Flag overrides below can be overridden by Phenotype sync, which makes this test flaky
- runShellCommand("device_config put textclassifier config_updater_model_enabled false");
- runShellCommand("device_config put textclassifier model_download_manager_enabled true");
- runShellCommand("device_config put textclassifier model_download_backoff_delay_in_millis 5");
-
- textClassifier = extServicesTextClassifierRule.getTextClassifier();
- startExtservicesProcess();
+ extServicesTextClassifierRule.addDeviceConfigOverride("config_updater_model_enabled", "false");
+ extServicesTextClassifierRule.addDeviceConfigOverride("model_download_manager_enabled", "true");
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "model_download_backoff_delay_in_millis", "5");
+ extServicesTextClassifierRule.addDeviceConfigOverride("testing_locale_list_override", "en-US");
+ extServicesTextClassifierRule.overrideDeviceConfig();
+
+ extServicesTextClassifierRule.enableVerboseLogging();
+ // Verbose logging only takes effect after restarting ExtServices
+ extServicesTextClassifierRule.forceStopExtServices();
}
@After
public void tearDown() throws Exception {
- runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
- runShellCommand("device_config delete textclassifier manifest_url_annotator_ru");
- runShellCommand("device_config put textclassifier config_updater_model_enabled true");
- runShellCommand("device_config delete textclassifier multi_language_support_enabled");
- runShellCommand(
- "device_config put textclassifier model_download_backoff_delay_in_millis 3600000");
+ // This is to reset logging/locale_override for ExtServices.
+ extServicesTextClassifierRule.forceStopExtServices();
}
@Test
- public void smokeTest() throws IOException, InterruptedException {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + V804_EN_ANNOTATOR_MANIFEST_URL);
+ public void smokeTest() throws Exception {
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG));
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
}
@Test
- public void downgradeModel() throws IOException, InterruptedException {
+ public void downgradeModel() throws Exception {
// Download an experimental model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
-
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 500,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
// Downgrade to an older model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + V804_EN_ANNOTATOR_MANIFEST_URL);
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG));
- }
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
}
@Test
- public void upgradeModel() throws IOException, InterruptedException {
+ public void upgradeModel() throws Exception {
// Download a model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + V804_EN_ANNOTATOR_MANIFEST_URL);
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG));
- }
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
// Upgrade to an experimental model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
-
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 500,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
}
@Test
- public void clearFlag() throws IOException, InterruptedException {
+ public void clearFlag() throws Exception {
// Download a new model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
-
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 500,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
// Revert the flag.
- {
- runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
- // Fallback to use the universal model.
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 500,
- () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride("manifest_url_annotator_en", "");
+ // Fallback to use the universal model.
+ assertWithRetries(
+ () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG));
}
@Test
- public void modelsForMultipleLanguagesDownloaded() throws IOException, InterruptedException {
- runShellCommand("device_config put textclassifier multi_language_support_enabled true");
- testingLocaleListOverrideRule.set("en-US", "ru-RU");
+ public void modelsForMultipleLanguagesDownloaded() throws Exception {
+ extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true");
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "testing_locale_list_override", "en-US,ru-RU");
// download en model
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
// download ru model
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_ru "
- + V804_RU_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 500,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_ru", V804_RU_ANNOTATOR_MANIFEST_URL);
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- assertWithRetries(/* maxAttempts= */ 10, /* sleepMs= */ 500, this::verifyActiveRussianModel);
+ assertWithRetries(this::verifyActiveRussianModel);
assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 500,
() -> verifyActiveModel(/* text= */ "français", /* expectedVersion= */ FACTORY_MODEL_TAG));
}
- private void assertWithRetries(int maxAttempts, int sleepMs, Runnable assertRunnable)
- throws InterruptedException {
- for (int i = 0; i < maxAttempts; i++) {
+ private void assertWithRetries(Runnable assertRunnable) throws Exception {
+ for (int i = 0; i < ASSERT_MAX_ATTEMPTS; i++) {
try {
+ extServicesTextClassifierRule.overrideDeviceConfig();
assertRunnable.run();
break; // success. Bail out.
} catch (AssertionError ex) {
- if (i == maxAttempts - 1) { // last attempt, give up.
+ if (i == ASSERT_MAX_ATTEMPTS - 1) { // last attempt, give up.
+ extServicesTextClassifierRule.dumpDefaultTextClassifierService();
throw ex;
} else {
- Thread.sleep(sleepMs);
+ Thread.sleep(ASSERT_SLEEP_BEFORE_RETRY_MS);
}
+ } catch (Exception unknownException) {
+ throw unknownException;
}
}
}
private void verifyActiveModel(String text, String expectedVersion) {
TextClassification textClassification =
- textClassifier.classifyText(new Request.Builder(text, 0, text.length()).build());
+ extServicesTextClassifierRule
+ .getTextClassifier()
+ .classifyText(new Request.Builder(text, 0, text.length()).build());
// The result id contains the name of the just used model.
+ Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId());
assertThat(textClassification.getId()).contains(expectedVersion);
}
@@ -222,16 +179,4 @@ public class ModelDownloaderIntegrationTest {
private void verifyActiveRussianModel() {
verifyActiveModel("привет", V804_RU_TAG);
}
-
- private void startExtservicesProcess() {
- // Start the process of ExtServices by sending it a text classifier request.
- textClassifier.classifyText(new TextClassification.Request.Builder("abc", 0, 3).build());
- }
-
- private static void runShellCommand(String cmd) {
- Log.v(TAG, "run shell command: " + cmd);
- Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation();
- UiAutomation uiAutomation = instrumentation.getUiAutomation();
- uiAutomation.executeShellCommand(cmd);
- }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
index eac2af3..76d04e0 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
@@ -37,14 +37,19 @@ import com.google.common.util.concurrent.SettableFuture;
import java.io.File;
import java.net.URI;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@RunWith(JUnit4.class)
public final class ModelDownloaderServiceImplTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private static final long BYTES_WRITTEN = 1L;
private static final String DOWNLOAD_URI =
"https://www.gstatic.com/android/text_classifier/r/v999/en.fb";
@@ -66,7 +71,6 @@ public final class ModelDownloaderServiceImplTest {
@Before
public void setUp() {
- MockitoAnnotations.initMocks(this);
this.targetModelFile =
new File(ApplicationProvider.getApplicationContext().getCacheDir(), "model.fb");
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
index 3ceb47b..5f8247d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
@@ -20,64 +20,72 @@ import android.app.UiAutomation;
import android.content.pm.PackageManager;
import android.content.pm.PackageManager.NameNotFoundException;
import android.provider.DeviceConfig;
+import android.util.Log;
import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.io.ByteStreams;
+import java.io.FileInputStream;
+import java.io.IOException;
import org.junit.rules.ExternalResource;
/** A rule that manages a text classifier that is backed by the ExtServices. */
public final class ExtServicesTextClassifierRule extends ExternalResource {
+ private static final String TAG = "androidtc";
private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
"textclassifier_service_package_override";
private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
- private String textClassifierServiceOverrideFlagOldValue;
+ private UiAutomation uiAutomation;
+ private DeviceConfig.Properties originalProperties;
+ private DeviceConfig.Properties.Builder newPropertiesBuilder;
@Override
- protected void before() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- try {
- uiAutomation.adoptShellPermissionIdentity();
- textClassifierServiceOverrideFlagOldValue =
- DeviceConfig.getString(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- null);
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- getExtServicesPackageName(),
- /* makeDefault= */ false);
- } finally {
- uiAutomation.dropShellPermissionIdentity();
- }
+ protected void before() throws Exception {
+ uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ uiAutomation.adoptShellPermissionIdentity();
+ originalProperties = DeviceConfig.getProperties(DeviceConfig.NAMESPACE_TEXTCLASSIFIER);
+ newPropertiesBuilder =
+ new DeviceConfig.Properties.Builder(DeviceConfig.NAMESPACE_TEXTCLASSIFIER)
+ .setString(
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, getExtServicesPackageName());
+ overrideDeviceConfig();
}
@Override
protected void after() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
try {
- uiAutomation.adoptShellPermissionIdentity();
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- textClassifierServiceOverrideFlagOldValue,
- /* makeDefault= */ false);
+ DeviceConfig.setProperties(originalProperties);
+ } catch (Throwable t) {
+ Log.e(TAG, "Failed to reset DeviceConfig", t);
} finally {
uiAutomation.dropShellPermissionIdentity();
}
}
- private static String getExtServicesPackageName() {
- PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager();
- try {
- packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
- return PKG_NAME_GOOGLE_EXTSERVICES;
- } catch (NameNotFoundException e) {
- return PKG_NAME_AOSP_EXTSERVICES;
- }
+ public void addDeviceConfigOverride(String name, String value) {
+ newPropertiesBuilder.setString(name, value);
+ }
+
+ /**
+ * Overrides the TextClassifier DeviceConfig manually.
+ *
+ * <p>This will clean up all device configs not in newPropertiesBuilder.
+ *
+ * <p>We will need to call this everytime before testing, because DeviceConfig can be synced in
+ * background at anytime. DeviceConfig#setSyncDisabledMode is to disable sync, however it's a
+ * hidden API.
+ */
+ public void overrideDeviceConfig() throws Exception {
+ DeviceConfig.setProperties(newPropertiesBuilder.build());
+ }
+
+ /** Force stop ExtServices. Force-stop-and-start can be helpful to reload some states. */
+ public void forceStopExtServices() {
+ runShellCommand("am force-stop com.google.android.ext.services");
+ runShellCommand("am force-stop android.ext.services");
}
public TextClassifier getTextClassifier() {
@@ -87,4 +95,38 @@ public final class ExtServicesTextClassifierRule extends ExternalResource {
textClassificationManager.setTextClassifier(null); // Reset TC overrides
return textClassificationManager.getTextClassifier();
}
+
+ public void dumpDefaultTextClassifierService() {
+ runShellCommand(
+ "dumpsys activity service com.google.android.ext.services/"
+ + "com.android.textclassifier.DefaultTextClassifierService");
+ runShellCommand("cmd device_config list textclassifier");
+ }
+
+ public void enableVerboseLogging() {
+ runShellCommand("setprop log.tag.androidtc VERBOSE");
+ }
+
+ private void runShellCommand(String cmd) {
+ Log.v(TAG, "run shell command: " + cmd);
+ try (FileInputStream output =
+ new FileInputStream(uiAutomation.executeShellCommand(cmd).getFileDescriptor())) {
+ String cmdOutput = new String(ByteStreams.toByteArray(output));
+ if (!cmdOutput.isEmpty()) {
+ Log.d(TAG, "cmd output: " + cmdOutput);
+ }
+ } catch (IOException ioe) {
+ Log.w(TAG, "failed to get cmd output", ioe);
+ }
+ }
+
+ private static String getExtServicesPackageName() {
+ PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager();
+ try {
+ packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
+ return PKG_NAME_GOOGLE_EXTSERVICES;
+ } catch (NameNotFoundException e) {
+ return PKG_NAME_AOSP_EXTSERVICES;
+ }
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java
deleted file mode 100644
index 7d46e97..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.testing;
-
-import android.app.UiAutomation;
-import android.os.LocaleList;
-import android.util.Log;
-import androidx.test.platform.app.InstrumentationRegistry;
-import org.junit.rules.ExternalResource;
-
-/** class for overriding testing_locale_list_override from {@link TextClassifierSettings} */
-public final class TestingLocaleListOverrideRule extends ExternalResource {
- private static final String TAG = "TestingLocaleListOverrideRule";
-
- private LocaleList originalLocaleList;
-
- @Override
- protected void before() {
- originalLocaleList = LocaleList.getDefault();
- }
-
- public void set(String... localeTags) {
- if (localeTags.length == 0) {
- return;
- }
- runShellCommand(
- "device_config put textclassifier testing_locale_list_override "
- + String.join(",", localeTags));
- }
-
- @Override
- protected void after() {
- runShellCommand(
- "device_config put textclassifier testing_locale_list_override "
- + originalLocaleList.toLanguageTags());
- runShellCommand("device_config delete textclassifier testing_locale_list_override");
- }
-
- private static void runShellCommand(String cmd) {
- Log.v(TAG, "run shell command: " + cmd);
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- uiAutomation.executeShellCommand(cmd);
- }
-}
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
index 7421579..6ebf1cf 100644
--- a/native/actions/actions-entity-data.bfbs
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
index 21584b6..e906f93 100644
--- a/native/actions/actions-entity-data.fbs
+++ b/native/actions/actions-entity-data.fbs
@@ -18,7 +18,7 @@
namespace libtextclassifier3;
table ActionsEntityData {
// Extracted text.
- text:string (shared);
+ text:string (key, shared);
}
root_type libtextclassifier3.ActionsEntityData;
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index b1a042c..9f9a8d4 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -17,6 +17,7 @@
#include "actions/actions-suggestions.h"
#include <memory>
+#include <string>
#include <vector>
#include "utils/base/statusor.h"
@@ -40,6 +41,7 @@
#include "utils/strings/stringpiece.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
@@ -809,12 +811,14 @@ bool ActionsSuggestions::SetupModelInput(
void ActionsSuggestions::PopulateTextReplies(
const tflite::Interpreter* interpreter, int suggestion_index,
- int score_index, const std::string& type,
+ int score_index, const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
ActionsSuggestionsResponse* response) const {
const std::vector<tflite::StringRef> replies =
model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
const TensorView<float> scores =
model_executor_->OutputView<float>(score_index, interpreter);
+
for (int i = 0; i < replies.size(); i++) {
if (replies[i].len == 0) {
continue;
@@ -823,8 +827,12 @@ void ActionsSuggestions::PopulateTextReplies(
if (score < preconditions_.min_reply_score_threshold) {
continue;
}
- response->actions.push_back(
- {std::string(replies[i].str, replies[i].len), type, score});
+ std::string response_text(replies[i].str, replies[i].len);
+ if (blocklist.contains(response_text)) {
+ continue;
+ }
+
+ response->actions.push_back({response_text, type, score, priority_score});
}
}
@@ -909,10 +917,12 @@ bool ActionsSuggestions::ReadModelOutput(
// Read smart reply predictions.
if (!response->output_filtered_min_triggering_score &&
model_->tflite_model_spec()->output_replies() >= 0) {
+ absl::flat_hash_set<std::string> empty_blocklist;
PopulateTextReplies(interpreter,
model_->tflite_model_spec()->output_replies(),
model_->tflite_model_spec()->output_replies_scores(),
- model_->smart_reply_action_type()->str(), response);
+ model_->smart_reply_action_type()->str(),
+ /* priority_score */ 0.0, empty_blocklist, response);
}
// Read actions suggestions.
@@ -950,17 +960,26 @@ bool ActionsSuggestions::ReadModelOutput(
const int suggestions_index = metadata->output_suggestions();
const int suggestions_scores_index =
metadata->output_suggestions_scores();
+ absl::flat_hash_set<std::string> response_text_blocklist;
switch (metadata->prediction_type()) {
case PredictionType_NEXT_MESSAGE_PREDICTION:
if (!task_spec || task_spec->type()->size() == 0) {
TC3_LOG(WARNING) << "Task type not provided, use default "
"smart_reply_action_type!";
}
+ if (task_spec) {
+ if (task_spec->response_text_blocklist()) {
+ for (const auto& val : *task_spec->response_text_blocklist()) {
+ response_text_blocklist.insert(val->str());
+ }
+ }
+ }
PopulateTextReplies(
interpreter, suggestions_index, suggestions_scores_index,
task_spec ? task_spec->type()->str()
: model_->smart_reply_action_type()->str(),
- response);
+ task_spec ? task_spec->priority_score() : 0.0,
+ response_text_blocklist, response);
break;
case PredictionType_INTENT_TRIGGERING:
PopulateIntentTriggering(interpreter, suggestions_index,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 32edc78..87f55fb 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -43,6 +43,7 @@
#include "utils/utf8/unilib.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
+#include "absl/container/flat_hash_set.h"
namespace libtextclassifier3 {
@@ -176,7 +177,8 @@ class ActionsSuggestions {
void PopulateTextReplies(const tflite::Interpreter* interpreter,
int suggestion_index, int score_index,
- const std::string& type,
+ const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
ActionsSuggestionsResponse* response) const;
void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 062d527..b51ebc7 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -1798,6 +1798,7 @@ TEST_F(ActionsSuggestionsTest,
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kMultiTaskSrEmojiModelFileName);
+
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello?",
@@ -1807,9 +1808,31 @@ TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 5);
EXPECT_EQ(response.actions[0].response_text, "😁");
- EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT");
- EXPECT_EQ(response.actions[1].response_text, "Yes");
- EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
+ EXPECT_EQ(response.actions[0].type, "text_reply");
+ EXPECT_EQ(response.actions[1].response_text, "👋");
+ EXPECT_EQ(response.actions[1].type, "text_reply");
+ EXPECT_EQ(response.actions[2].response_text, "Yes");
+ EXPECT_EQ(response.actions[2].type, "text_reply");
+}
+
+TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kMultiTaskSrEmojiModelFileName);
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "a pleasure chatting",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 3);
+ EXPECT_EQ(response.actions[0].response_text, "😁");
+ EXPECT_EQ(response.actions[0].type, "text_reply");
+ EXPECT_EQ(response.actions[1].response_text, "😘");
+ EXPECT_EQ(response.actions[1].type, "text_reply");
+ EXPECT_EQ(response.actions[2].response_text, "Okay");
+ EXPECT_EQ(response.actions[2].type, "text_reply");
}
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 8c03eeb..0d8c7ad 100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -36,6 +36,17 @@ enum PredictionType : int {
ENTITY_ANNOTATION = 3,
}
+namespace libtextclassifier3;
+enum RankingOptionsSortType : int {
+ SORT_TYPE_UNSPECIFIED = 0,
+
+ // Rank results (or groups) by score, then type
+ SORT_TYPE_SCORE = 1,
+
+ // Rank results (or groups) by priority score, then score, then type
+ SORT_TYPE_PRIORITY_SCORE = 2,
+}
+
// Prediction metadata for an arbitrary task.
namespace libtextclassifier3;
table PredictionMetadata {
@@ -315,10 +326,11 @@ table ActionSuggestionSpec {
// Additional entity information.
serialized_entity_data:string (shared);
- // Priority score used for internal conflict resolution.
+ // For ranking and internal conflict resolution.
priority_score:float = 0;
entity_data:ActionsEntityData;
+ response_text_blocklist:[string];
}
// Options to specify triggering behaviour per action class.
@@ -416,6 +428,8 @@ table RankingOptions {
// If true, keep actions from the same entities together for ranking.
group_by_annotations:bool = true;
+
+ sort_type:RankingOptionsSortType = SORT_TYPE_SCORE;
}
// Entity data to set from capturing groups.
diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc
index d52ecaa..46e392a 100644
--- a/native/actions/ranker.cc
+++ b/native/actions/ranker.cc
@@ -20,6 +20,8 @@
#include <set>
#include <vector>
+#include "actions/actions_model_generated.h"
+
#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
#endif
@@ -34,11 +36,22 @@ namespace libtextclassifier3 {
namespace {
void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
- std::sort(actions->begin(), actions->end(),
- [](const ActionSuggestion& a, const ActionSuggestion& b) {
- return a.score > b.score ||
- (a.score >= b.score && a.type < b.type);
- });
+ std::stable_sort(actions->begin(), actions->end(),
+ [](const ActionSuggestion& a, const ActionSuggestion& b) {
+ return a.score > b.score ||
+ (a.score >= b.score && a.type < b.type);
+ });
+}
+
+void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) {
+ std::stable_sort(
+ actions->begin(), actions->end(),
+ [](const ActionSuggestion& a, const ActionSuggestion& b) {
+ return a.priority_score > b.priority_score ||
+ (a.priority_score >= b.priority_score && a.score > b.score) ||
+ (a.priority_score >= b.priority_score && a.score >= b.score &&
+ a.type < b.type);
+ });
}
template <typename T>
@@ -241,13 +254,8 @@ bool ActionsSuggestionsRanker::RankActions(
const reflection::Schema* annotations_entity_data_schema) const {
if (options_->deduplicate_suggestions() ||
options_->deduplicate_suggestions_by_span()) {
- // First order suggestions by priority score for deduplication.
- std::sort(
- response->actions.begin(), response->actions.end(),
- [](const ActionSuggestion& a, const ActionSuggestion& b) {
- return a.priority_score > b.priority_score ||
- (a.priority_score >= b.priority_score && a.score > b.score);
- });
+ // Order suggestions by [priority score -> score] for deduplication
+ SortByPriorityAndScoreAndType(&response->actions);
// Deduplicate, keeping the higher score actions.
if (options_->deduplicate_suggestions()) {
@@ -275,6 +283,8 @@ bool ActionsSuggestionsRanker::RankActions(
}
}
+ bool sort_by_priority =
+ options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
// Suppress smart replies if actions are present.
if (options_->suppress_smart_replies_with_actions()) {
std::vector<ActionSuggestion> non_smart_reply_actions;
@@ -316,17 +326,35 @@ bool ActionsSuggestionsRanker::RankActions(
// Sort within each group by score.
for (std::vector<ActionSuggestion>& group : groups) {
- SortByScoreAndType(&group);
+ if (sort_by_priority) {
+ SortByPriorityAndScoreAndType(&group);
+ } else {
+ SortByScoreAndType(&group);
+ }
}
- // Sort groups by maximum score.
- std::sort(groups.begin(), groups.end(),
- [](const std::vector<ActionSuggestion>& a,
- const std::vector<ActionSuggestion>& b) {
- return a.begin()->score > b.begin()->score ||
- (a.begin()->score >= b.begin()->score &&
- a.begin()->type < b.begin()->type);
- });
+ // Sort groups by maximum score or priority score.
+ if (sort_by_priority) {
+ std::stable_sort(
+ groups.begin(), groups.end(),
+ [](const std::vector<ActionSuggestion>& a,
+ const std::vector<ActionSuggestion>& b) {
+ return (a.begin()->priority_score > b.begin()->priority_score) ||
+ (a.begin()->priority_score >= b.begin()->priority_score &&
+ a.begin()->score > b.begin()->score) ||
+ (a.begin()->priority_score >= b.begin()->priority_score &&
+ a.begin()->score >= b.begin()->score &&
+ a.begin()->type < b.begin()->type);
+ });
+ } else {
+ std::stable_sort(groups.begin(), groups.end(),
+ [](const std::vector<ActionSuggestion>& a,
+ const std::vector<ActionSuggestion>& b) {
+ return a.begin()->score > b.begin()->score ||
+ (a.begin()->score >= b.begin()->score &&
+ a.begin()->type < b.begin()->type);
+ });
+ }
// Flatten result.
const size_t num_actions = response->actions.size();
@@ -336,9 +364,9 @@ bool ActionsSuggestionsRanker::RankActions(
response->actions.insert(response->actions.end(), actions.begin(),
actions.end());
}
-
+ } else if (sort_by_priority) {
+ SortByPriorityAndScoreAndType(&response->actions);
} else {
- // Order suggestions independently by score.
SortByScoreAndType(&response->actions);
}
diff --git a/native/actions/ranker_test.cc b/native/actions/ranker_test.cc
index b52cf45..5eba45f 100644
--- a/native/actions/ranker_test.cc
+++ b/native/actions/ranker_test.cc
@@ -18,6 +18,7 @@
#include <string>
+#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "utils/zlib/zlib.h"
#include "gmock/gmock.h"
@@ -308,12 +309,12 @@ TEST(RankingTest, GroupsActionsByAnnotations) {
response.actions.push_back({/*response_text=*/"",
/*type=*/"call_phone",
/*score=*/1.0,
- /*priority_score=*/1.0,
+ /*priority_score=*/0.0,
/*annotations=*/{annotation}});
response.actions.push_back({/*response_text=*/"",
/*type=*/"add_contact",
/*score=*/0.0,
- /*priority_score=*/0.0,
+ /*priority_score=*/1.0,
/*annotations=*/{annotation}});
}
response.actions.push_back({/*response_text=*/"How are you?",
@@ -338,23 +339,75 @@ TEST(RankingTest, GroupsActionsByAnnotations) {
IsAction("text_reply", "How are you?", 0.5)}));
}
-TEST(RankingTest, SortsActionsByScore) {
+TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
ActionsSuggestionsResponse response;
+ response.actions.push_back({/*response_text=*/"How are you?",
+ /*type=*/"text_reply",
+ /*score=*/2.0,
+ /*priority_score=*/0.0});
{
ActionSuggestionAnnotation annotation;
annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
/*text=*/"911"};
annotation.entity = ClassificationResult("phone", 1.0);
response.actions.push_back({/*response_text=*/"",
+ /*type=*/"add_contact",
+ /*score=*/0.0,
+ /*priority_score=*/1.0,
+ /*annotations=*/{annotation}});
+ response.actions.push_back({/*response_text=*/"",
/*type=*/"call_phone",
/*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/{annotation}});
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"add_contact2",
+ /*score=*/0.5,
/*priority_score=*/1.0,
/*annotations=*/{annotation}});
+ }
+ RankingOptionsT options;
+ options.group_by_annotations = true;
+ options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ // The text reply should be last, even though it's score is higher than
+ // any other scores -- because it's priority_score is lower than the max
+ // of those with the 'phone' annotation
+ EXPECT_THAT(response.actions,
+ testing::ElementsAreArray({
+ // Group 1 (Phone annotation)
+ IsAction("add_contact2", "", 0.5), // priority_score=1.0
+ IsAction("add_contact", "", 0.0), // priority_score=1.0
+ IsAction("call_phone", "", 1.0), // priority_score=0.0
+ IsAction("text_reply", "How are you?", 2.0), // Group 2
+ }));
+}
+
+TEST(RankingTest, SortsActionsByScore) {
+ const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
+ ActionsSuggestionsResponse response;
+ {
+ ActionSuggestionAnnotation annotation;
+ annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
+ /*text=*/"911"};
+ annotation.entity = ClassificationResult("phone", 1.0);
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"call_phone",
+ /*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/{annotation}});
response.actions.push_back({/*response_text=*/"",
/*type=*/"add_contact",
/*score=*/0.0,
- /*priority_score=*/0.0,
+ /*priority_score=*/1.0,
/*annotations=*/{annotation}});
}
response.actions.push_back({/*response_text=*/"How are you?",
@@ -378,5 +431,40 @@ TEST(RankingTest, SortsActionsByScore) {
IsAction("add_contact", "", 0.0)}));
}
+TEST(RankingTest, SortsActionsByPriority) {
+ const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
+ ActionsSuggestionsResponse response;
+ // emoji replies given higher priority_score
+ response.actions.push_back({/*response_text=*/"😁",
+ /*type=*/"text_reply",
+ /*score=*/0.5,
+ /*priority_score=*/1.0});
+ response.actions.push_back({/*response_text=*/"👋",
+ /*type=*/"text_reply",
+ /*score=*/0.4,
+ /*priority_score=*/1.0});
+ response.actions.push_back({/*response_text=*/"Yes",
+ /*type=*/"text_reply",
+ /*score=*/1.0,
+ /*priority_score=*/0.0});
+ RankingOptionsT options;
+ // Don't group by annotation.
+ options.group_by_annotations = false;
+ options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ EXPECT_THAT(response.actions, testing::ElementsAreArray(
+ {IsAction("text_reply", "😁", 0.5),
+ IsAction("text_reply", "👋", 0.4),
+ // Ranked last because of priority score
+ IsAction("text_reply", "Yes", 1.0)}));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index 77e556c..0fa7f7e 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index c468bd5..6107e98 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index ec421a1..436ed93 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index 24be6c6..935691d 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index fd7ddf2..2c9f74b 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index c969c56..cdb7523 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index d171898..ac28fa2 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index 937552b..d864b79 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 32bd29c..e0d4241 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -973,11 +973,11 @@ CodepointSpan Annotator::SuggestSelection(
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
- std::sort(candidates.annotated_spans[0].begin(),
- candidates.annotated_spans[0].end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- return a.span.first < b.span.first;
- });
+ std::stable_sort(candidates.annotated_spans[0].begin(),
+ candidates.annotated_spans[0].end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ return a.span.first < b.span.first;
+ });
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
@@ -987,13 +987,14 @@ CodepointSpan Annotator::SuggestSelection(
return original_click_indices;
}
- std::sort(candidate_indices.begin(), candidate_indices.end(),
- [this, &candidates](int a, int b) {
- return GetPriorityScore(
- candidates.annotated_spans[0][a].classification) >
- GetPriorityScore(
- candidates.annotated_spans[0][b].classification);
- });
+ std::stable_sort(
+ candidate_indices.begin(), candidate_indices.end(),
+ [this, &candidates](int a, int b) {
+ return GetPriorityScore(
+ candidates.annotated_spans[0][a].classification) >
+ GetPriorityScore(
+ candidates.annotated_spans[0][b].classification);
+ });
for (const int i : candidate_indices) {
if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
@@ -1173,7 +1174,7 @@ bool Annotator::ResolveConflict(
}
}
- std::sort(
+ std::stable_sort(
conflicting_indices.begin(), conflicting_indices.end(),
[this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
if (scores_lengths[i].first == scores_lengths[j].first &&
@@ -1241,7 +1242,7 @@ bool Annotator::ResolveConflict(
chosen_indices_for_source_ptr->insert(considered_candidate);
}
- std::sort(chosen_indices->begin(), chosen_indices->end());
+ std::stable_sort(chosen_indices->begin(), chosen_indices->end());
return true;
}
@@ -1414,10 +1415,11 @@ namespace {
// Sorts the classification results from high score to low score.
void SortClassificationResults(
std::vector<ClassificationResult>* classification_results) {
- std::sort(classification_results->begin(), classification_results->end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
+ std::stable_sort(
+ classification_results->begin(), classification_results->end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
}
} // namespace
@@ -1936,10 +1938,11 @@ std::vector<ClassificationResult> Annotator::ClassifyText(
}
// Sort results according to score.
- std::sort(results.begin(), results.end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
+ std::stable_sort(
+ results.begin(), results.end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
if (results.empty()) {
results = {{Collections::Other(), 1.0}};
@@ -2297,19 +2300,19 @@ Status Annotator::AnnotateSingleInput(
// Also sort them according to the end position and collection, so that the
// deduplication code below can assume that same spans and classifications
// form contiguous blocks.
- std::sort(candidates->begin(), candidates->end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- if (a.span.first != b.span.first) {
- return a.span.first < b.span.first;
- }
+ std::stable_sort(candidates->begin(), candidates->end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ if (a.span.first != b.span.first) {
+ return a.span.first < b.span.first;
+ }
- if (a.span.second != b.span.second) {
- return a.span.second < b.span.second;
- }
+ if (a.span.second != b.span.second) {
+ return a.span.second < b.span.second;
+ }
- return a.classification[0].collection <
- b.classification[0].collection;
- });
+ return a.classification[0].collection <
+ b.classification[0].collection;
+ });
std::vector<int> candidate_indices;
if (!ResolveConflicts(*candidates, context, tokens,
@@ -2904,10 +2907,10 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
return false;
}
}
- std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
- [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
- return lhs.score < rhs.score;
- });
+ std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
+ [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
+ return lhs.score < rhs.score;
+ });
// Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
// them greedily as long as they do not overlap with any previously picked
@@ -2936,7 +2939,7 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
chunks->push_back(scored_chunk.token_span);
}
- std::sort(chunks->begin(), chunks->end());
+ std::stable_sort(chunks->begin(), chunks->end());
return true;
}
diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc
index 7d5f440..ff0c775 100644
--- a/native/annotator/datetime/datetime-grounder.cc
+++ b/native/annotator/datetime/datetime-grounder.cc
@@ -16,6 +16,7 @@
#include "annotator/datetime/datetime-grounder.h"
+#include <algorithm>
#include <limits>
#include <unordered_map>
#include <vector>
@@ -250,10 +251,10 @@ StatusOr<std::vector<DatetimeParseResult>> DatetimeGrounder::Ground(
}
// Sort the date time units by component type.
- std::sort(date_components.begin(), date_components.end(),
- [](DatetimeComponent a, DatetimeComponent b) {
- return a.component_type > b.component_type;
- });
+ std::stable_sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
result.datetime_components.swap(date_components);
datetime_parse_result.push_back(result);
}
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index 867c886..94a0961 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -16,6 +16,8 @@
#include "annotator/datetime/extractor.h"
+#include <algorithm>
+
#include "annotator/datetime/utils.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
@@ -347,10 +349,11 @@ bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input,
}
}
- std::sort(found_numbers.begin(), found_numbers.end(),
- [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
- return a.first < b.first;
- });
+ std::stable_sort(
+ found_numbers.begin(), found_numbers.end(),
+ [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
+ return a.first < b.first;
+ });
int sum = 0;
int running_value = -1;
diff --git a/native/annotator/datetime/regex-parser.cc b/native/annotator/datetime/regex-parser.cc
index 4dc9c56..5daabd5 100644
--- a/native/annotator/datetime/regex-parser.cc
+++ b/native/annotator/datetime/regex-parser.cc
@@ -16,6 +16,7 @@
#include "annotator/datetime/regex-parser.h"
+#include <algorithm>
#include <iterator>
#include <set>
#include <unordered_set>
@@ -191,17 +192,17 @@ StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
// Resolve conflicts by always picking the longer span and breaking ties by
// selecting the earlier entry in the list for a given locale.
- std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
- [](const std::pair<DatetimeParseResultSpan, int>& a,
- const std::pair<DatetimeParseResultSpan, int>& b) {
- if ((a.first.span.second - a.first.span.first) !=
- (b.first.span.second - b.first.span.first)) {
- return (a.first.span.second - a.first.span.first) >
- (b.first.span.second - b.first.span.first);
- } else {
- return a.second < b.second;
- }
- });
+ std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+ [](const std::pair<DatetimeParseResultSpan, int>& a,
+ const std::pair<DatetimeParseResultSpan, int>& b) {
+ if ((a.first.span.second - a.first.span.first) !=
+ (b.first.span.second - b.first.span.first)) {
+ return (a.first.span.second - a.first.span.first) >
+ (b.first.span.second - b.first.span.first);
+ } else {
+ return a.second < b.second;
+ }
+ });
std::vector<DatetimeParseResultSpan> results;
std::vector<DatetimeParseResultSpan> resolved_found_spans;
@@ -394,10 +395,10 @@ bool RegexDatetimeParser::ExtractDatetime(
}
// Sort the date time units by component type.
- std::sort(date_components.begin(), date_components.end(),
- [](DatetimeComponent a, DatetimeComponent b) {
- return a.component_type > b.component_type;
- });
+ std::stable_sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
result.datetime_components.swap(date_components);
results->push_back(result);
}
diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc
index 640ceec..2c5a43c 100644
--- a/native/annotator/translate/translate.cc
+++ b/native/annotator/translate/translate.cc
@@ -16,6 +16,7 @@
#include "annotator/translate/translate.h"
+#include <algorithm>
#include <memory>
#include "annotator/collections.h"
@@ -142,11 +143,11 @@ TranslateAnnotator::BackoffDetectLanguages(
result.push_back({key, value});
}
- std::sort(result.begin(), result.end(),
- [](TranslateAnnotator::LanguageConfidence& a,
- TranslateAnnotator::LanguageConfidence& b) {
- return a.confidence > b.confidence;
- });
+ std::stable_sort(result.begin(), result.end(),
+ [](const TranslateAnnotator::LanguageConfidence& a,
+ const TranslateAnnotator::LanguageConfidence& b) {
+ return a.confidence > b.confidence;
+ });
return result;
}
diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc
index 469cb1f..49c9ca0 100644
--- a/native/lang_id/common/embedding-network.cc
+++ b/native/lang_id/common/embedding-network.cc
@@ -16,6 +16,8 @@
#include "lang_id/common/embedding-network.h"
+#include <vector>
+
#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"
diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc
index ab8a1a6..4e304fe 100644
--- a/native/lang_id/common/fel/feature-extractor.cc
+++ b/native/lang_id/common/fel/feature-extractor.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/fel/feature-extractor.h"
#include <string>
+#include <vector>
#include "lang_id/common/fel/feature-types.h"
#include "lang_id/common/fel/fel-parser.h"
diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc
index af41e29..60dcc46 100644
--- a/native/lang_id/common/fel/workspace.cc
+++ b/native/lang_id/common/fel/workspace.cc
@@ -18,6 +18,7 @@
#include <atomic>
#include <string>
+#include <vector>
namespace libtextclassifier3 {
namespace mobile {
diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h
index f13d802..2ac5b26 100644
--- a/native/lang_id/common/fel/workspace.h
+++ b/native/lang_id/common/fel/workspace.h
@@ -23,6 +23,7 @@
#include <stddef.h>
+#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 19afcc4..fc925ea 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -29,6 +29,8 @@
#endif
#include <sys/stat.h>
+#include <string>
+
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"
diff --git a/native/lang_id/common/lite_strings/str-split.cc b/native/lang_id/common/lite_strings/str-split.cc
index 199bb69..d227eec 100644
--- a/native/lang_id/common/lite_strings/str-split.cc
+++ b/native/lang_id/common/lite_strings/str-split.cc
@@ -16,6 +16,8 @@
#include "lang_id/common/lite_strings/str-split.h"
+#include <vector>
+
namespace libtextclassifier3 {
namespace mobile {
diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc
index 750341d..249ed57 100644
--- a/native/lang_id/common/math/softmax.cc
+++ b/native/lang_id/common/math/softmax.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/math/softmax.h"
#include <algorithm>
+#include <vector>
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/fastexp.h"
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
index dc36fb7..51c8c47 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.cc
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -16,7 +16,9 @@
#include "lang_id/fb_model/lang-id-from-fb.h"
+#include <memory>
#include <string>
+#include <utility>
#include "lang_id/fb_model/model-provider-from-fb.h"
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
index 43bf860..d14d403 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.cc
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -16,7 +16,9 @@
#include "lang_id/fb_model/model-provider-from-fb.h"
+#include <memory>
#include <string>
+#include <utility>
#include "lang_id/common/file/file-utils.h"
#include "lang_id/common/file/mmap.h"
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
index 92359a9..f7c66f7 100644
--- a/native/lang_id/lang-id.cc
+++ b/native/lang_id/lang-id.cc
@@ -21,6 +21,7 @@
#include <memory>
#include <string>
#include <unordered_map>
+#include <utility>
#include <vector>
#include "lang_id/common/embedding-feature-interface.h"
diff --git a/native/utils/codepoint-range.cc b/native/utils/codepoint-range.cc
index e26b160..a4cd485 100644
--- a/native/utils/codepoint-range.cc
+++ b/native/utils/codepoint-range.cc
@@ -31,10 +31,11 @@ void SortCodepointRanges(
CodepointRangeStruct(range->start(), range->end()));
}
- std::sort(sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(),
- [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) {
- return a.start < b.start;
- });
+ std::stable_sort(
+ sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(),
+ [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) {
+ return a.start < b.start;
+ });
}
// Returns true if given codepoint is covered by the given sorted vector of
diff --git a/native/utils/grammar/parsing/parser.cc b/native/utils/grammar/parsing/parser.cc
index 4e39a98..a9e99ba 100644
--- a/native/utils/grammar/parsing/parser.cc
+++ b/native/utils/grammar/parsing/parser.cc
@@ -16,6 +16,7 @@
#include "utils/grammar/parsing/parser.h"
+#include <algorithm>
#include <unordered_map>
#include "utils/grammar/parsing/parse-tree.h"
@@ -177,14 +178,14 @@ std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input,
}
}
- std::sort(symbols.begin(), symbols.end(),
- [](const Symbol& a, const Symbol& b) {
- // Sort by increasing (end, start) position to guarantee the
- // matcher requirement that the tokens are fed in non-decreasing
- // end position order.
- return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
- std::tie(b.codepoint_span.second, b.codepoint_span.first);
- });
+ std::stable_sort(
+ symbols.begin(), symbols.end(), [](const Symbol& a, const Symbol& b) {
+ // Sort by increasing (end, start) position to guarantee the
+ // matcher requirement that the tokens are fed in non-decreasing
+ // end position order.
+ return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
+ std::tie(b.codepoint_span.second, b.codepoint_span.first);
+ });
return symbols;
}
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index dd29e3c..c134550 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -16,6 +16,8 @@
#include "utils/grammar/utils/ir.h"
+#include <algorithm>
+
#include "utils/i18n/locale.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
@@ -28,14 +30,16 @@ constexpr size_t kMaxHashTableSize = 100;
template <typename T>
void SortForBinarySearchLookup(T* entries) {
- std::sort(entries->begin(), entries->end(),
- [](const auto& a, const auto& b) { return a->key < b->key; });
+ std::stable_sort(
+ entries->begin(), entries->end(),
+ [](const auto& a, const auto& b) { return a->key < b->key; });
}
template <typename T>
void SortStructsForBinarySearchLookup(T* entries) {
- std::sort(entries->begin(), entries->end(),
- [](const auto& a, const auto& b) { return a.key() < b.key(); });
+ std::stable_sort(
+ entries->begin(), entries->end(),
+ [](const auto& a, const auto& b) { return a.key() < b.key(); });
}
bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
@@ -76,13 +80,14 @@ bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
Ir::LhsSet sorted_lhs = lhs_set;
- std::sort(sorted_lhs.begin(), sorted_lhs.end(),
- [](const Ir::Lhs& a, const Ir::Lhs& b) {
- return std::tie(a.nonterminal, a.callback.id, a.callback.param,
- a.preconditions.max_whitespace_gap) <
- std::tie(b.nonterminal, b.callback.id, b.callback.param,
- b.preconditions.max_whitespace_gap);
- });
+ std::stable_sort(
+ sorted_lhs.begin(), sorted_lhs.end(),
+ [](const Ir::Lhs& a, const Ir::Lhs& b) {
+ return std::tie(a.nonterminal, a.callback.id, a.callback.param,
+ a.preconditions.max_whitespace_gap) <
+ std::tie(b.nonterminal, b.callback.id, b.callback.param,
+ b.preconditions.max_whitespace_gap);
+ });
return lhs_set;
}
@@ -300,10 +305,10 @@ void Ir::SerializeTerminalRules(
TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
}
}
- std::sort(terminal_rules.begin(), terminal_rules.end(),
- [](const TerminalEntry& a, const TerminalEntry& b) {
- return a.terminal < b.terminal;
- });
+ std::stable_sort(terminal_rules.begin(), terminal_rules.end(),
+ [](const TerminalEntry& a, const TerminalEntry& b) {
+ return a.terminal < b.terminal;
+ });
// Index the entries in sorted order.
std::vector<int> index(terminal_rules_sets.size(), 0);
diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc
index e6db06d..141ce5d 100644
--- a/native/utils/grammar/utils/locale-shard-map.cc
+++ b/native/utils/grammar/utils/locale-shard-map.cc
@@ -40,8 +40,8 @@ std::vector<Locale> LocaleTagsToLocaleList(const std::string& locale_tags) {
locale_list.emplace_back(locale);
}
}
- std::sort(locale_list.begin(), locale_list.end(),
- [](const Locale& a, const Locale& b) { return a < b; });
+ std::stable_sort(locale_list.begin(), locale_list.end(),
+ [](const Locale& a, const Locale& b) { return a < b; });
return locale_list;
}
diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h
index 30c7aed..c23b5dc 100644
--- a/native/utils/testing/test_data_generator.h
+++ b/native/utils/testing/test_data_generator.h
@@ -20,6 +20,7 @@
#include <algorithm>
#include <iostream>
#include <random>
+#include <string>
#include "utils/strings/stringpiece.h"
@@ -35,6 +36,18 @@ class TestDataGenerator {
return dist(random_engine_);
}
+ template <>
+ bool generate() {
+ std::bernoulli_distribution dist(0.5);
+ return dist(random_engine_);
+ }
+
+ template <>
+ char generate() {
+ std::uniform_int_distribution<int> dist(0, 25);
+ return dist(random_engine_) + 'a';
+ }
+
template <typename T, typename std::enable_if_t<
std::is_floating_point<T>::value>* = nullptr>
T generate() {
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 463d910..644dde8 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -27,6 +27,8 @@ namespace builtin {
TfLiteRegistration* Register_ADD();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_CONV_2D();
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
+TfLiteRegistration* Register_AVERAGE_POOL_2D();
TfLiteRegistration* Register_EQUAL();
TfLiteRegistration* Register_FULLY_CONNECTED();
TfLiteRegistration* Register_GREATER_EQUAL();
@@ -89,7 +91,9 @@ TfLiteRegistration* Register_GREATER();
#include "utils/tflite/dist_diversification.h"
#include "utils/tflite/string_projection.h"
#include "utils/tflite/text_encoder.h"
+#include "utils/tflite/text_encoder3s.h"
#include "utils/tflite/token_encoder.h"
+
namespace tflite {
namespace ops {
namespace custom {
@@ -114,6 +118,14 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
tflite::ops::builtin::Register_CONV_2D(),
/*min_version=*/1,
/*max_version=*/5);
+ resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
+ tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(),
+ /*min_version=*/1,
+ /*max_version=*/6);
+ resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
+ tflite::ops::builtin::Register_AVERAGE_POOL_2D(),
+ /*min_version=*/1,
+ /*max_version=*/1);
resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
::tflite::ops::builtin::Register_EQUAL());
@@ -289,6 +301,8 @@ std::unique_ptr<tflite::OpResolver> BuildOpResolver(
tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
resolver->AddCustom("TextEncoder",
tflite::ops::custom::Register_TEXT_ENCODER());
+ resolver->AddCustom("TextEncoder3S",
+ tflite::ops::custom::Register_TEXT_ENCODER3S());
resolver->AddCustom("TokenEncoder",
tflite::ops::custom::Register_TOKEN_ENCODER());
resolver->AddCustom(
diff --git a/native/utils/tflite/encoder_common.cc b/native/utils/tflite/encoder_common.cc
index 8f9f2a8..eb319f9 100644
--- a/native/utils/tflite/encoder_common.cc
+++ b/native/utils/tflite/encoder_common.cc
@@ -58,6 +58,11 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
out->data.i32 + output_offset + from_this_element,
in.data.i32[value_index]);
} break;
+ case kTfLiteInt64: {
+ std::fill(out->data.i64 + output_offset,
+ out->data.i64 + output_offset + from_this_element,
+ in.data.i64[value_index]);
+ } break;
case kTfLiteFloat32: {
std::fill(out->data.f + output_offset,
out->data.f + output_offset + from_this_element,
@@ -78,6 +83,12 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
value);
} break;
+ case kTfLiteInt64: {
+ const int64_t value =
+ (output_offset > 0) ? out->data.i64[output_offset - 1] : 0;
+ std::fill(out->data.i64 + output_offset, out->data.i64 + output_size,
+ value);
+ } break;
case kTfLiteFloat32: {
const float value =
(output_offset > 0) ? out->data.f[output_offset - 1] : 0;
diff --git a/native/utils/tflite/text_encoder3s.cc b/native/utils/tflite/text_encoder3s.cc
new file mode 100644
index 0000000..0b5e65b
--- /dev/null
+++ b/native/utils/tflite/text_encoder3s.cc
@@ -0,0 +1,243 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/text_encoder3s.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/tflite/encoder_common.h"
+#include "utils/tflite/text_encoder_config_generated.h"
+#include "utils/tokenfree/byte_encoder.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Input parameters for the op.
+constexpr int kInputTextInd = 0;
+
+constexpr int kTextLengthInd = 1;
+constexpr int kMaxLengthInd = 2;
+constexpr int kInputAttrInd = 3;
+
+// Output parameters for the op.
+constexpr int kOutputEncodedInd = 0;
+constexpr int kOutputPositionInd = 1;
+constexpr int kOutputLengthsInd = 2;
+constexpr int kOutputAttrInd = 3;
+
+// Initializes text encoder object from serialized parameters.
+void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
+ std::unique_ptr<ByteEncoder> encoder(new ByteEncoder());
+ return encoder.release();
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<ByteEncoder*>(buffer);
+}
+
+namespace {
+TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
+ int max_output_length) {
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[kOutputEncodedInd]];
+
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(
+ context, &output_encoded,
+ CreateIntArray({kEncoderBatchSize, max_output_length})));
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPositionInd]];
+
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(
+ context, &output_positions,
+ CreateIntArray({kEncoderBatchSize, max_output_length})));
+
+ const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& output =
+ context->tensors[node->outputs->data[kOutputAttrInd + i]];
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(
+ context, &output,
+ CreateIntArray({kEncoderBatchSize, max_output_length})));
+ }
+ return kTfLiteOk;
+}
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check that the batch dimension is kEncoderBatchSize.
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[kInputTextInd]];
+ TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
+ TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
+
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[kOutputLengthsInd]];
+
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[kOutputEncodedInd]];
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPositionInd]];
+ output_encoded.type = kTfLiteInt32;
+ output_positions.type = kTfLiteInt32;
+ output_lengths.type = kTfLiteInt32;
+
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &output_lengths,
+ CreateIntArray({kEncoderBatchSize})));
+
+ // Check that there are enough outputs for attributes.
+ const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
+ num_output_attrs);
+
+ // Copy attribute types from input to output tensors.
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& input =
+ context->tensors[node->inputs->data[kInputAttrInd + i]];
+ TfLiteTensor& output =
+ context->tensors[node->outputs->data[kOutputAttrInd + i]];
+ output.type = input.type;
+ }
+
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[kMaxLengthInd]];
+
+ if (tflite::IsConstantTensor(&output_length)) {
+ return ResizeOutputTensors(context, node, output_length.data.i64[0]);
+ } else {
+ tflite::SetTensorToDynamic(&output_encoded);
+ tflite::SetTensorToDynamic(&output_positions);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& output_attr =
+ context->tensors[node->outputs->data[kOutputAttrInd + i]];
+ tflite::SetTensorToDynamic(&output_attr);
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ if (node->user_data == nullptr) {
+ return kTfLiteError;
+ }
+ auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data);
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[kInputTextInd]];
+ const int num_strings_in_tensor = tflite::GetStringCount(&input_text);
+ const int num_strings =
+ context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0];
+
+ // Check that the number of strings is not bigger than the input tensor size.
+ TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings);
+
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[kOutputEncodedInd]];
+ if (tflite::IsDynamicTensor(&output_encoded)) {
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[kMaxLengthInd]];
+ TF_LITE_ENSURE_OK(
+ context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
+ }
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPositionInd]];
+
+ std::vector<int> encoded_total;
+ std::vector<int> encoded_positions;
+ std::vector<int> encoded_offsets;
+ encoded_offsets.reserve(num_strings);
+ const int max_output_length = output_encoded.dims->data[1];
+ const int max_encoded_position = max_output_length;
+
+ for (int i = 0; i < num_strings; ++i) {
+ const auto& strref = tflite::GetString(&input_text, i);
+ std::vector<int64_t> encoded;
+ text_encoder->Encode(
+ libtextclassifier3::StringPiece(strref.str, strref.len), &encoded);
+ encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
+ encoded_offsets.push_back(encoded_total.size());
+ for (int i = 0; i < encoded.size(); ++i) {
+ encoded_positions.push_back(std::min(i, max_encoded_position - 1));
+ }
+ }
+
+ // Copy encoding to output tensor.
+ const int start_offset =
+ std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
+ int output_offset = 0;
+ int32_t* output_buffer = output_encoded.data.i32;
+ int32_t* output_positions_buffer = output_positions.data.i32;
+ for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
+ output_buffer[output_offset] = encoded_total[i];
+ output_positions_buffer[output_offset] = encoded_positions[i];
+ }
+
+ // Save output encoded length.
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[kOutputLengthsInd]];
+ output_lengths.data.i32[0] = output_offset;
+
+ // Do padding.
+ for (; output_offset < max_output_length; ++output_offset) {
+ output_buffer[output_offset] = 0;
+ output_positions_buffer[output_offset] = 0;
+ }
+
+ // Process attributes, all checks of sizes and types are done in Prepare.
+ const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
+ num_output_attrs);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
+ context->tensors[node->inputs->data[kInputAttrInd + i]],
+ encoded_offsets, start_offset, context,
+ &context->tensors[node->outputs->data[kOutputAttrInd + i]]);
+ if (attr_status != kTfLiteOk) {
+ return attr_status;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER3S() {
+ static TfLiteRegistration registration = {
+ libtextclassifier3::Initialize, libtextclassifier3::Free,
+ libtextclassifier3::Prepare, libtextclassifier3::Eval};
+ return &registration;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/native/utils/tflite/text_encoder3s.h b/native/utils/tflite/text_encoder3s.h
new file mode 100644
index 0000000..50e1e64
--- /dev/null
+++ b/native/utils/tflite/text_encoder3s.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// An encoder that produces positional and attributes encodings for a
+// transformer style model based on byte segmentation of text.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER3S();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
diff --git a/native/utils/tokenfree/byte_encoder.cc b/native/utils/tokenfree/byte_encoder.cc
new file mode 100644
index 0000000..c79d3a2
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenfree/byte_encoder.h"
+
+#include <vector>
+namespace libtextclassifier3 {
+
+bool ByteEncoder::Encode(StringPiece input_text,
+ std::vector<int64_t>* encoded_text) const {
+ const int len = input_text.size();
+ if (len <= 0) {
+ *encoded_text = {};
+ return true;
+ }
+
+ int size = input_text.size();
+ encoded_text->resize(size);
+
+ const auto& text = input_text.ToString();
+ for (int i = 0; i < size; i++) {
+ int64_t encoding = static_cast<int64_t>(text[i]);
+ (*encoded_text)[i] = encoding;
+ }
+
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenfree/byte_encoder.h b/native/utils/tokenfree/byte_encoder.h
new file mode 100644
index 0000000..1a495ec
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
+
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/container/string-set.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Encoder to segment/tokenize strings into bytes
+class ByteEncoder {
+ public:
+ bool Encode(StringPiece input_text, std::vector<int64_t>* encoded_text) const;
+ ByteEncoder() {}
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc
new file mode 100644
index 0000000..d4d119e
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder_test.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenfree/byte_encoder.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/sorted-strings-table.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
+TEST(EncoderTest, SimpleTokenization) {
+ const ByteEncoder encoder;
+ {
+ std::vector<int64_t> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
+ EXPECT_THAT(encoded_text,
+ ElementsAre(104, 101, 108, 108, 111, 116, 104, 101, 114, 101));
+ }
+}
+
+TEST(EncoderTest, SimpleTokenization2) {
+ const ByteEncoder encoder;
+ {
+ std::vector<int64_t> encoded_text;
+ EXPECT_TRUE(encoder.Encode("Hello", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(72, 101, 108, 108, 111));
+ }
+}
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
index 071141c..7038517 100644
--- a/native/utils/tokenizer.cc
+++ b/native/utils/tokenizer.cc
@@ -43,11 +43,12 @@ Tokenizer::Tokenizer(
codepoint_ranges_.emplace_back(range->UnPack());
}
- std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
- const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
- return a->start < b->start;
- });
+ std::stable_sort(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(),
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
+ const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
+ return a->start < b->start;
+ });
SortCodepointRanges(internal_tokenizer_codepoint_ranges,
&internal_tokenizer_codepoint_ranges_);
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
index bc30fcf..f539ba7 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
@@ -37,15 +37,20 @@ import androidx.test.filters.LargeTest;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@LargeTest
@RunWith(AndroidJUnit4.class)
public class SmartSuggestionsLogSessionTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private static final String RESULT_ID = "resultId";
private static final String REPLY = "reply";
private static final float SCORE = 0.5f;
@@ -55,7 +60,6 @@ public class SmartSuggestionsLogSessionTest {
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
session =
new SmartSuggestionsLogSession(