diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-03-15 18:59:59 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-03-15 18:59:59 +0000 |
commit | 3d5da76ab70029db2810b2d7b7611bafdc258c50 (patch) | |
tree | 3a3d2ccd661e94baac725672eb79068f9778a135 | |
parent | f106c46253e7ec42e5d39ab1b3fa3ada443917b2 (diff) | |
parent | 8ebbedca8443b38941a7ddadc8245fcc83c6f866 (diff) | |
download | libtextclassifier-android12-mainline-sdkext-release.tar.gz |
Snap for 8303596 from 8ebbedca8443b38941a7ddadc8245fcc83c6f866 to mainline-sdkext-releaseandroid-mainline-12.0.0_r109aml_sdk_311710000android12-mainline-sdkext-release
Change-Id: I5ae59fe3453aa2ba4fe57f826c1f51bae4092d3f
66 files changed, 1273 insertions, 583 deletions
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java index fd64581..bde3898 100644 --- a/java/src/com/android/textclassifier/ExtrasUtils.java +++ b/java/src/com/android/textclassifier/ExtrasUtils.java @@ -87,7 +87,9 @@ public final class ExtrasUtils { return classification.getExtras().getBundle(FOREIGN_LANGUAGE); } - /** @see #getTopLanguage(Intent) */ + /** + * @see #getTopLanguage(Intent) + */ static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) { final int maxSize = Math.min(3, languageScores.getEntities().size()); final String[] languages = diff --git a/java/src/com/android/textclassifier/ModelFileManagerImpl.java b/java/src/com/android/textclassifier/ModelFileManagerImpl.java index 45426d0..e3b646f 100644 --- a/java/src/com/android/textclassifier/ModelFileManagerImpl.java +++ b/java/src/com/android/textclassifier/ModelFileManagerImpl.java @@ -390,7 +390,18 @@ final class ModelFileManagerImpl implements ModelFileManager { localePreferences.get(0), targetLocale)); } - return findBestModelFile(modelType, targetLocale); + ModelFile modelFile = findBestModelFile(modelType, targetLocale); + TcLog.d( + TAG, + String.format( + Locale.US, + "findBestModelFile: best model: %s; localePreferences: %s; detectedLocales: %s;" + + " targetLocale: %s", + modelFile, + localePreferences, + detectedLocales, + targetLocale)); + return modelFile; } /** diff --git a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java index dae0442..67e300d 100644 --- a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java +++ b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java @@ -66,8 +66,8 @@ public final class ResultIdUtils { } /** Returns if the result id was generated from the default text classifier. */ - public static boolean isFromDefaultTextClassifier(String resultId) { - return resultId.startsWith(CLASSIFIER_ID + '|'); + public static boolean isFromDefaultTextClassifier(@Nullable String resultId) { + return resultId != null && resultId.startsWith(CLASSIFIER_ID + '|'); } /** Returns all the model names encoded in the signature. */ diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java index 1ae79ce..9bdfb5e 100644 --- a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java +++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java @@ -195,7 +195,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager @Override public void onDownloadCompleted( ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) { - TcLog.v(TAG, "Start to clean up models and update model lookup cache..."); + TcLog.d(TAG, "Start to clean up models and update model lookup cache..."); // Step 1: Clean up ManifestEnrollment table List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments(); List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>(); @@ -286,7 +286,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager // Clear the cache table and rebuild the cache based on ModelView table private void updateCache() { synchronized (cacheLock) { - TcLog.v(TAG, "Updating model lookup cache..."); + TcLog.d(TAG, "Updating model lookup cache..."); for (String modelType : ModelType.values()) { modelLookupCache.get(modelType).clear(); } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java index b125f13..af33e21 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java @@ -44,6 +44,7 @@ import com.android.textclassifier.utils.IndentingPrintWriter; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Enums; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.hash.Hashing; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; @@ -54,6 +55,7 @@ import java.time.Instant; import java.util.List; import java.util.Locale; import java.util.UUID; +import java.util.concurrent.Callable; import javax.annotation.Nullable; /** Manager to listen to config update and download latest models. */ @@ -64,6 +66,7 @@ public final class ModelDownloadManager { private final Context appContext; private final Class<? extends ListenableWorker> modelDownloadWorkerClass; + private final Callable<WorkManager> workManagerSupplier; private final DownloadedModelManager downloadedModelManager; private final TextClassifierSettings settings; private final ListeningExecutorService executorService; @@ -84,6 +87,7 @@ public final class ModelDownloadManager { this( appContext, ModelDownloadWorker.class, + () -> WorkManager.getInstance(appContext), DownloadedModelManagerImpl.getInstance(appContext), settings, executorService); @@ -93,11 +97,13 @@ public final class ModelDownloadManager { public ModelDownloadManager( Context appContext, Class<? extends ListenableWorker> modelDownloadWorkerClass, + Callable<WorkManager> workManagerSupplier, DownloadedModelManager downloadedModelManager, TextClassifierSettings settings, ListeningExecutorService executorService) { this.appContext = Preconditions.checkNotNull(appContext); this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass); + this.workManagerSupplier = Preconditions.checkNotNull(workManagerSupplier); this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager); this.settings = Preconditions.checkNotNull(settings); this.executorService = Preconditions.checkNotNull(executorService); @@ -121,22 +127,31 @@ public final class ModelDownloadManager { /** Returns the downlaoded models for the given modelType. */ @Nullable public List<File> listDownloadedModels(@ModelTypeDef String modelType) { - return downloadedModelManager.listModels(modelType); + try { + return downloadedModelManager.listModels(modelType); + } catch (Throwable t) { + TcLog.e(TAG, "Failed to list downloaded models", t); + return ImmutableList.of(); + } } /** Notifies the model downlaoder that the text classifier service is created. */ public void onTextClassifierServiceCreated() { - DeviceConfig.addOnPropertiesChangedListener( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener); - appContext.registerReceiver( - localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED)); - TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered."); - if (!settings.isModelDownloadManagerEnabled()) { - return; + try { + DeviceConfig.addOnPropertiesChangedListener( + DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener); + appContext.registerReceiver( + localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED)); + TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered."); + if (!settings.isModelDownloadManagerEnabled()) { + return; + } + maybeOverrideLocaleListForTesting(); + TcLog.d(TAG, "Try to schedule model download work because TextClassifierService started."); + scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED); + } catch (Throwable t) { + TcLog.e(TAG, "Failed inside onTextClassifierServiceCreated", t); } - maybeOverrideLocaleListForTesting(); - TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started."); - scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED); } // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing. @@ -146,8 +161,12 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - TcLog.v(TAG, "Try to schedule model download work because of system locale changes."); - scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED); + TcLog.d(TAG, "Try to schedule model download work because of system locale changes."); + try { + scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED); + } catch (Throwable t) { + TcLog.e(TAG, "Failed inside onLocaleChanged", t); + } } // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing. @@ -157,16 +176,24 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - maybeOverrideLocaleListForTesting(); - TcLog.v(TAG, "Try to schedule model download work because of device config changes."); - scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED); + TcLog.d(TAG, "Try to schedule model download work because of device config changes."); + try { + maybeOverrideLocaleListForTesting(); + scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED); + } catch (Throwable t) { + TcLog.e(TAG, "Failed inside onTextClassifierDeviceConfigChanged", t); + } } /** Clean up internal states on destroying. */ public void destroy() { - DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener); - appContext.unregisterReceiver(localeChangedReceiver); - TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager"); + try { + DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener); + appContext.unregisterReceiver(localeChangedReceiver); + TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager"); + } catch (Throwable t) { + TcLog.e(TAG, "Failed to destroy ModelDownloadManager", t); + } } /** @@ -178,10 +205,14 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - printWriter.println("ModelDownloadManager:"); - printWriter.increaseIndent(); - downloadedModelManager.dump(printWriter); - printWriter.decreaseIndent(); + try { + printWriter.println("ModelDownloadManager:"); + printWriter.increaseIndent(); + downloadedModelManager.dump(printWriter); + printWriter.decreaseIndent(); + } catch (Throwable t) { + TcLog.e(TAG, "Failed to dump ModelDownloadManager", t); + } } /** @@ -193,54 +224,62 @@ public final class ModelDownloadManager { private void scheduleDownloadWork(int reasonToSchedule) { long workId = Hashing.farmHashFingerprint64().hashUnencodedChars(UUID.randomUUID().toString()).asLong(); - NetworkType networkType = - Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType()) - .or(NetworkType.UNMETERED); - OneTimeWorkRequest downloadRequest = - new OneTimeWorkRequest.Builder(modelDownloadWorkerClass) - .setConstraints( - new Constraints.Builder() - .setRequiredNetworkType(networkType) - .setRequiresBatteryNotLow(true) - .setRequiresStorageNotLow(true) - .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle()) - .setRequiresCharging(settings.getManifestDownloadRequiresCharging()) - .build()) - .setBackoffCriteria( - BackoffPolicy.EXPONENTIAL, - settings.getModelDownloadBackoffDelayInMillis(), - MILLISECONDS) - .setInputData( - new Data.Builder() - .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId) - .putLong( - ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, - Instant.now().toEpochMilli()) - .build()) - .build(); - ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture = - WorkManager.getInstance(appContext) - .enqueueUniqueWork( - UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest) - .getResult(); - Futures.addCallback( - enqueueResultFuture, - new FutureCallback<Operation.State.SUCCESS>() { - @Override - public void onSuccess(Operation.State.SUCCESS unused) { - TcLog.v(TAG, "Download work scheduled."); - TextClassifierDownloadLogger.downloadWorkScheduled( - workId, reasonToSchedule, /* failedToSchedule= */ false); - } + try { + NetworkType networkType = + Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType()) + .or(NetworkType.UNMETERED); + OneTimeWorkRequest downloadRequest = + new OneTimeWorkRequest.Builder(modelDownloadWorkerClass) + .setConstraints( + new Constraints.Builder() + .setRequiredNetworkType(networkType) + .setRequiresBatteryNotLow(true) + .setRequiresStorageNotLow(true) + .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle()) + .setRequiresCharging(settings.getManifestDownloadRequiresCharging()) + .build()) + .setBackoffCriteria( + BackoffPolicy.EXPONENTIAL, + settings.getModelDownloadBackoffDelayInMillis(), + MILLISECONDS) + .setInputData( + new Data.Builder() + .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId) + .putLong( + ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, + Instant.now().toEpochMilli()) + .build()) + .build(); + ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture = + workManagerSupplier + .call() + .enqueueUniqueWork( + UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest) + .getResult(); + Futures.addCallback( + enqueueResultFuture, + new FutureCallback<Operation.State.SUCCESS>() { + @Override + public void onSuccess(Operation.State.SUCCESS unused) { + TcLog.d(TAG, "Download work scheduled."); + TextClassifierDownloadLogger.downloadWorkScheduled( + workId, reasonToSchedule, /* failedToSchedule= */ false); + } - @Override - public void onFailure(Throwable t) { - TcLog.e(TAG, "Failed to schedule download work: ", t); - TextClassifierDownloadLogger.downloadWorkScheduled( - workId, reasonToSchedule, /* failedToSchedule= */ true); - } - }, - executorService); + @Override + public void onFailure(Throwable t) { + TcLog.e(TAG, "Failed to schedule download work: ", t); + TextClassifierDownloadLogger.downloadWorkScheduled( + workId, reasonToSchedule, /* failedToSchedule= */ true); + } + }, + executorService); + } catch (Throwable t) { + // TODO(licha): this is just for temporary fix. Refactor the try-catch in the future. + TcLog.e(TAG, "Failed to schedule download work: ", t); + TextClassifierDownloadLogger.downloadWorkScheduled( + workId, reasonToSchedule, /* failedToSchedule= */ true); + } } private void maybeOverrideLocaleListForTesting() { diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java index 6e04e16..3db0815 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java @@ -113,6 +113,7 @@ public final class ModelDownloadWorker extends ListenableWorker { @Override public final ListenableFuture<ListenableWorker.Result> startWork() { + TcLog.d(TAG, "Start download work..."); workStartedTimeMillis = getCurrentTimeMillis(); // Notice: startWork() is invoked on the main thread if (!settings.isModelDownloadManagerEnabled()) { @@ -121,7 +122,6 @@ public final class ModelDownloadWorker extends ListenableWorker { TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED); return Futures.immediateFuture(ListenableWorker.Result.failure()); } - TcLog.v(TAG, "Start download work..."); if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) { TcLog.d(TAG, "Max attempt reached. Abort download work."); logDownloadWorkCompleted( @@ -134,7 +134,7 @@ public final class ModelDownloadWorker extends ListenableWorker { downloadResult -> { Preconditions.checkNotNull(manifestsToDownload); downloadedModelManager.onDownloadCompleted(manifestsToDownload); - TcLog.v(TAG, "Download work completed: " + downloadResult); + TcLog.d(TAG, "Download work completed: " + downloadResult); if (downloadResult.failureCount() == 0) { logDownloadWorkCompleted( downloadResult.successCount() > 0 @@ -239,7 +239,7 @@ public final class ModelDownloadWorker extends ListenableWorker { return Futures.whenAllComplete(downloadResultFutures) .call( () -> { - TcLog.v(TAG, "All Download Tasks Completed"); + TcLog.d(TAG, "All Download Tasks Completed"); int successCount = 0; int failureCount = 0; for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) { @@ -333,7 +333,7 @@ public final class ModelDownloadWorker extends ListenableWorker { Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl); if (downloadedManifest != null && downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) { - TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl); + TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl); return Futures.immediateVoidFuture(); } if (pendingDownloads.containsKey(manifestUrl)) { diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java index 2244e9a..0b76f22 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java @@ -99,7 +99,7 @@ final class ModelDownloaderImpl implements ModelDownloader { new FutureCallback<File>() { @Override public void onSuccess(File pendingModelFile) { - TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); + TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); } @Override @@ -170,11 +170,11 @@ final class ModelDownloaderImpl implements ModelDownloader { } catch (IOException e) { throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e); } - TcLog.v(TAG, "Pending model file passed validation."); + TcLog.d(TAG, "Pending model file passed validation."); } private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) { - TcLog.v(TAG, "Starting a new connection to ModelDownloaderService"); + TcLog.d(TAG, "Starting a new connection to ModelDownloaderService"); return CallbackToFutureAdapter.getFuture( completer -> { conn.attachCompleter(completer); @@ -197,7 +197,7 @@ final class ModelDownloaderImpl implements ModelDownloader { // restult future will hang there until time out. (WorkManager forces a 10-min running time.) private static ListenableFuture<Long> scheduleDownload( IModelDownloaderService service, URI uri, File targetFile) { - TcLog.v(TAG, "Scheduling a new download task with ModelDownloaderService"); + TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService"); return CallbackToFutureAdapter.getFuture( completer -> { service.download( @@ -236,7 +236,7 @@ final class ModelDownloaderImpl implements ModelDownloader { @Override public void onServiceConnected(ComponentName componentName, IBinder iBinder) { - TcLog.v(TAG, "DownloaderService connected"); + TcLog.d(TAG, "DownloaderService connected"); completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder))); } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java index e4ebbfa..6d7e47e 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java @@ -39,7 +39,7 @@ public final class ModelDownloaderService extends Service { @Override public IBinder onBind(Intent intent) { - TcLog.v(TAG, "Binding to ModelDownloadService"); + TcLog.d(TAG, "Binding to ModelDownloadService"); return iBinder; } } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java index 439588b..47e6f19 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java @@ -91,7 +91,7 @@ final class ModelDownloaderServiceImpl extends IModelDownloaderService.Stub { @Override public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) { - TcLog.v(TAG, "Download request received: " + uri); + TcLog.d(TAG, "Download request received: " + uri); try { File targetFile = new File(targetFilePath); File tempMetadataFile = getMetadataFile(targetFile); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java index 0e3842c..ddab8bd 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java @@ -17,7 +17,10 @@ package com.android.textclassifier; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import android.content.Context; import android.os.CancellationSignal; @@ -39,6 +42,7 @@ import com.android.os.AtomsProto.Atom; import com.android.os.AtomsProto.TextClassifierApiUsageReported; import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType; import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType; +import com.android.textclassifier.common.ModelType; import com.android.textclassifier.common.TextClassifierSettings; import com.android.textclassifier.common.statsd.StatsdTestUtils; import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger; @@ -47,21 +51,27 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import java.io.IOException; import java.util.List; import java.util.concurrent.Executor; import java.util.stream.Collectors; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @SmallTest @RunWith(AndroidJUnit4.class) public class DefaultTextClassifierServiceTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + /** A statsd config ID, which is arbitrary. */ private static final long CONFIG_ID = 689777; @@ -76,14 +86,21 @@ public class DefaultTextClassifierServiceTest { @Mock private TextClassifierService.Callback<TextLinks> textLinksCallback; @Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback; @Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback; + @Mock private ModelFileManager testModelFileManager; @Before - public void setup() { - MockitoAnnotations.initMocks(this); - - testInjector = new TestInjector(ApplicationProvider.getApplicationContext()); + public void setup() throws IOException { + testInjector = + new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager); defaultTextClassifierService = new DefaultTextClassifierService(testInjector); defaultTextClassifierService.onCreate(); + + when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); + when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) + .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); + when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) + .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); } @Before @@ -207,11 +224,8 @@ public class DefaultTextClassifierServiceTest { @Test public void missingModelFile_onFailureShouldBeCalled() throws Exception { - testInjector.setModelFileManager( - new ModelFileManagerImpl( - ApplicationProvider.getApplicationContext(), - ImmutableList.of(), - testInjector.createTextClassifierSettings())); + when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(null); defaultTextClassifierService.onCreate(); TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build(); @@ -247,12 +261,9 @@ public class DefaultTextClassifierServiceTest { private final Context context; private ModelFileManager modelFileManager; - private TestInjector(Context context) { + private TestInjector(Context context, ModelFileManager modelFileManager) { this.context = Preconditions.checkNotNull(context); - } - - private void setModelFileManager(ModelFileManager modelFileManager) { - this.modelFileManager = modelFileManager; + this.modelFileManager = Preconditions.checkNotNull(modelFileManager); } @Override @@ -263,9 +274,6 @@ public class DefaultTextClassifierServiceTest { @Override public ModelFileManager createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) { - if (modelFileManager == null) { - return TestDataUtils.createModelFileManagerForTesting(context); - } return modelFileManager; } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java index 5297640..0e40515 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java @@ -25,6 +25,7 @@ import android.os.LocaleList; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.filters.SmallTest; +import androidx.work.WorkManager; import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister; import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister; import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister; @@ -53,7 +54,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @SmallTest @RunWith(AndroidJUnit4.class) @@ -67,6 +69,7 @@ public final class ModelFileManagerImplTest { @Mock private DownloadedModelManager downloadedModelManager; @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private File rootTestDir; private ModelFileManagerImpl modelFileManager; @@ -75,7 +78,6 @@ public final class ModelFileManagerImplTest { @Before public void setup() { - MockitoAnnotations.initMocks(this); deviceConfig = new TestingDeviceConfig(); rootTestDir = new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir"); @@ -86,6 +88,7 @@ public final class ModelFileManagerImplTest { new ModelDownloadManager( context, ModelDownloadWorker.class, + () -> WorkManager.getInstance(context), downloadedModelManager, settings, MoreExecutors.newDirectExecutorService()); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java index bac4fa1..a19e3ff 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java @@ -16,12 +16,10 @@ package com.android.textclassifier; -import android.content.Context; -import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister; +import com.android.textclassifier.common.ModelFile; import com.android.textclassifier.common.ModelType; -import com.android.textclassifier.common.TextClassifierSettings; -import com.google.common.collect.ImmutableList; import java.io.File; +import java.io.IOException; /** Utils to access test data files. */ public final class TestDataUtils { @@ -30,7 +28,7 @@ public final class TestDataUtils { private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model"; /** Returns the root folder that contains the test data. */ - public static File getTestDataFolder() { + private static File getTestDataFolder() { return new File("/data/local/tmp/TextClassifierServiceTest/"); } @@ -38,24 +36,25 @@ public final class TestDataUtils { return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH); } + public static ModelFile getTestAnnotatorModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile(getTestAnnotatorModelFile(), ModelType.ANNOTATOR); + } + public static File getTestActionsModelFile() { return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH); } + public static ModelFile getTestActionsModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile( + getTestActionsModelFile(), ModelType.ACTIONS_SUGGESTIONS); + } + public static File getLangIdModelFile() { return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH); } - public static ModelFileManager createModelFileManagerForTesting(Context context) { - return new ModelFileManagerImpl( - context, - ImmutableList.of( - new RegularFileFullMatchLister( - ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true), - new RegularFileFullMatchLister( - ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true), - new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)), - new TextClassifierSettings()); + public static ModelFile getLangIdModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile(getLangIdModelFile(), ModelType.LANG_ID); } private TestDataUtils() {} diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java index 42177e6..e7bf90c 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java @@ -56,6 +56,10 @@ public class TextClassifierApiTest { @Before public void setup() { + extServicesTextClassifierRule.enableVerboseLogging(); + // Verbose logging only takes effect after restarting ExtServices + extServicesTextClassifierRule.forceStopExtServices(); + textClassifier = extServicesTextClassifierRule.getTextClassifier(); } @@ -81,8 +85,8 @@ public class TextClassifierApiTest { @Test public void classifyText() { - String text = "Contact me at droid@android.com"; - String classifiedText = "droid@android.com"; + String text = "Contact me at http://www.android.com"; + String classifiedText = "http://www.android.com"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = @@ -90,7 +94,7 @@ public class TextClassifierApiTest { TextClassification classification = textClassifier.classifyText(request); assertThat(classification.getEntityCount()).isGreaterThan(0); - assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL); + assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL); assertThat(classification.getText()).isEqualTo(classifiedText); assertThat(classification.getActions()).isNotEmpty(); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java index fb1aea8..c20ec8a 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java @@ -22,6 +22,9 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.testng.Assert.expectThrows; @@ -74,30 +77,34 @@ public class TextClassifierImplTest { private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US"); private static final String NO_TYPE = null; - @Mock private ModelFileManagerImpl.ModelFileLister mockModelFileLister; + @Mock private ModelFileManager modelFileManager; - private TextClassifierSettings settings; private Context context; private TestingDeviceConfig deviceConfig; - private TextClassifierImpl classifier; - - private final ModelFileManager modelFileManager = - TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext()); + private TextClassifierSettings settings; private LruCache<ModelFile, AnnotatorModel> annotatorModelCache; + private TextClassifierImpl classifier; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.initMocks(this); - deviceConfig = new TestingDeviceConfig(); - Context context = + this.context = new FakeContextBuilder() .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT) .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app") .build(); - this.context = context; - settings = new TextClassifierSettings(deviceConfig); - // TODO(veronikanikina): consider using a testing constructor here. - classifier = new TextClassifierImpl(context, settings, modelFileManager); + this.deviceConfig = new TestingDeviceConfig(); + this.settings = new TextClassifierSettings(deviceConfig); + this.annotatorModelCache = new LruCache<>(2); + this.classifier = + new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache); + + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); + when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) + .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); + when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) + .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); } @Test @@ -110,9 +117,7 @@ public class TextClassifierImplTest { int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat( @@ -120,6 +125,24 @@ public class TextClassifierImplTest { } @Test + public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException { + String text = "Contact me at droid@android.com"; + String selected = "droid"; + String suggested = "droid@android.com"; + int startIndex = text.indexOf(selected); + int endIndex = startIndex + selected.length(); + int smartStartIndex = text.indexOf(suggested); + int smartEndIndex = smartStartIndex + suggested.length(); + TextSelection.Request request = + new TextSelection.Request.Builder(text, startIndex, endIndex) + .setDefaultLocales(LOCALES) + .build(); + + classifier.suggestSelection(null, null, request); + verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any()); + } + + @Test public void testSuggestSelection_url() throws IOException { String text = "Visit http://www.android.com for more information"; String selected = "http"; @@ -129,9 +152,7 @@ public class TextClassifierImplTest { int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); @@ -144,9 +165,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(selected); int endIndex = startIndex + selected.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); @@ -160,7 +179,6 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(suggested); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1) - .setDefaultLocales(LOCALES) .setIncludeTextClassification(true) .build(); @@ -178,7 +196,6 @@ public class TextClassifierImplTest { String text = "Visit http://www.android.com for more information"; TextSelection.Request request = new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4) - .setDefaultLocales(LOCALES) .setIncludeTextClassification(false) .build(); @@ -194,9 +211,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(/* sessionId= */ null, null, request); @@ -210,9 +225,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); @@ -223,9 +236,7 @@ public class TextClassifierImplTest { public void testClassifyText_address() throws IOException { String text = "Brandschenkestrasse 110, Zürich, Switzerland"; TextClassification.Request request = - new TextClassification.Request.Builder(text, 0, text.length()) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, 0, text.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); @@ -238,9 +249,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); @@ -254,9 +263,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); @@ -275,9 +282,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); @@ -289,14 +294,12 @@ public class TextClassifierImplTest { LocaleList.setDefault(LocaleList.forLanguageTags("en")); String japaneseText = "これは日本語のテキストです"; TextClassification.Request request = - new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); RemoteAction translateAction = classification.getActions().get(0); assertEquals(1, classification.getActions().size()); - assertEquals("Translate", translateAction.getTitle().toString()); + assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction()); assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification)); Intent intent = ExtrasUtils.getActionsIntents(classification).get(0); @@ -323,18 +326,17 @@ public class TextClassifierImplTest { @Test public void testGenerateLinks_exclude() throws IOException { - String text = "You want apple@banana.com. See you tonight!"; + String text = "The number is +12122537077. See you tonight!"; List<String> hints = ImmutableList.of(); List<String> included = ImmutableList.of(); - List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL); + List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE); TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), - not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); + not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE))); } @Test @@ -344,7 +346,6 @@ public class TextClassifierImplTest { TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), @@ -361,7 +362,6 @@ public class TextClassifierImplTest { TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), @@ -573,29 +573,16 @@ public class TextClassifierImplTest { new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); - String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath(); - ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false); - - annotatorModelCache = new LruCache<>(2); - ModelFileManager modelFileManagerCached = - new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings); - TextClassifierImpl textClassifierImpl = - new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache); - LocaleList.setDefault(LocaleList.forLanguageTags("en")); String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = - new TextClassification.Request.Builder(englishText, 0, englishText.length()) - .setDefaultLocales(LOCALES) - .build(); - - when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel)); + new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationA = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationA = classifier.classifyText(null, null, request); assertThat(classificationA.getId()).contains("v701"); assertThat(classificationA.getText()).contains(classifiedText); @@ -609,9 +596,9 @@ public class TextClassifierImplTest { }); // Check modelFileB v801 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelB)); - TextClassification classificationB = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelB); + TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); @@ -625,9 +612,9 @@ public class TextClassifierImplTest { }); // Reload modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); @@ -651,28 +638,16 @@ public class TextClassifierImplTest { new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); - String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath(); - ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false); - - annotatorModelCache = new LruCache<>(settings.getMultiAnnotatorCacheSize()); - ModelFileManager modelFileManagerCached = - new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings); - TextClassifierImpl textClassifierImpl = - new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache); - LocaleList.setDefault(LocaleList.forLanguageTags("en")); + String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = - new TextClassification.Request.Builder(englishText, 0, englishText.length()) - .setDefaultLocales(LOCALES) - .build(); - - when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel)); + new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classification = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification.getId()).contains("v701"); assertThat(classification.getText()).contains(classifiedText); @@ -686,9 +661,9 @@ public class TextClassifierImplTest { }); // Check modelFileB v801 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelB)); - TextClassification classificationB = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelB); + TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); @@ -702,9 +677,9 @@ public class TextClassifierImplTest { }); // Reload modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java index 216cd5d..3aab211 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java @@ -27,14 +27,18 @@ import com.google.android.textclassifier.NamedVariant; import com.google.android.textclassifier.RemoteActionTemplate; import java.util.List; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @SmallTest @RunWith(AndroidJUnit4.class) public class TemplateIntentFactoryTest { + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final String TITLE_WITHOUT_ENTITY = "Map"; private static final String TITLE_WITH_ENTITY = "Map NW14D1"; private static final String DESCRIPTION = "Check the map"; @@ -71,7 +75,6 @@ public class TemplateIntentFactoryTest { @Before public void setup() { - MockitoAnnotations.initMocks(this); templateIntentFactory = new TemplateIntentFactory(); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java index ffd2ee4..3a8fefc 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java @@ -86,8 +86,8 @@ public class StatsdTestUtils { return ImmutableList.copyOf( metricsList.stream() .flatMap(statsLogReport -> statsLogReport.getEventMetrics().getDataList().stream()) - .flatMap(eventMetricData -> backfillAggregatedAtomsinEventMetric( - eventMetricData).stream()) + .flatMap( + eventMetricData -> backfillAggregatedAtomsinEventMetric(eventMetricData).stream()) .sorted(Comparator.comparing(EventMetricData::getElapsedTimestampNanos)) .map(EventMetricData::getAtom) .collect(Collectors.toList())); @@ -136,7 +136,7 @@ public class StatsdTestUtils { } private static ImmutableList<EventMetricData> backfillAggregatedAtomsinEventMetric( - EventMetricData metricData) { + EventMetricData metricData) { if (metricData.hasAtom()) { return ImmutableList.of(metricData); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java index c626ed7..9e11c09 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java @@ -46,7 +46,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @RunWith(AndroidJUnit4.class) public final class ModelDownloadManagerTest { @@ -61,14 +62,16 @@ public final class ModelDownloadManagerTest { public final TextClassifierDownloadLoggerTestRule loggerTestRule = new TextClassifierDownloadLoggerTestRule(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private TestingDeviceConfig deviceConfig; private WorkManager workManager; private ModelDownloadManager downloadManager; + private ModelDownloadManager downloadManagerWithBadWorkManager; @Mock DownloadedModelManager downloadedModelManager; @Before public void setUp() { - MockitoAnnotations.initMocks(this); Context context = ApplicationProvider.getApplicationContext(); WorkManagerTestInitHelper.initializeTestWorkManager(context); @@ -78,6 +81,17 @@ public final class ModelDownloadManagerTest { new ModelDownloadManager( context, ModelDownloadWorker.class, + () -> workManager, + downloadedModelManager, + new TextClassifierSettings(deviceConfig), + MoreExecutors.newDirectExecutorService()); + this.downloadManagerWithBadWorkManager = + new ModelDownloadManager( + context, + ModelDownloadWorker.class, + () -> { + throw new IllegalStateException("WorkManager may fail!"); + }, downloadedModelManager, new TextClassifierSettings(deviceConfig), MoreExecutors.newDirectExecutorService()); @@ -94,7 +108,20 @@ public final class ModelDownloadManagerTest { } @Test + public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); + downloadManagerWithBadWorkManager.onTextClassifierServiceCreated(); + + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test + TextClassifierDownloadWorkScheduled atom = + Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); + assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED); + assertThat(atom.getFailedToSchedule()).isTrue(); + } + + @Test public void onTextClassifierServiceCreated_requestEnqueued() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); downloadManager.onTextClassifierServiceCreated(); WorkInfo workInfo = @@ -102,21 +129,34 @@ public final class ModelDownloadManagerTest { DownloaderTestUtils.queryWorkInfos( workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } @Test public void onTextClassifierServiceCreated_localeListOverridden() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr"); downloadManager.onTextClassifierServiceCreated(); assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh")); assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } @Test + public void onLocaleChanged_workManagerCrashed() throws Exception { + downloadManagerWithBadWorkManager.onLocaleChanged(); + + TextClassifierDownloadWorkScheduled atom = + Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); + assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED); + assertThat(atom.getFailedToSchedule()).isTrue(); + } + + @Test public void onLocaleChanged_requestEnqueued() throws Exception { downloadManager.onLocaleChanged(); @@ -129,6 +169,16 @@ public final class ModelDownloadManagerTest { } @Test + public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception { + downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged(); + + TextClassifierDownloadWorkScheduled atom = + Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); + assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED); + assertThat(atom.getFailedToSchedule()).isTrue(); + } + + @Test public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception { downloadManager.onTextClassifierDeviceConfigChanged(); @@ -186,6 +236,13 @@ public final class ModelDownloadManagerTest { assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile); } + @Test + public void listDownloadedModels_doNotCrashOnError() throws Exception { + when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException()); + + assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty(); + } + private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception { TextClassifierDownloadWorkScheduled atom = Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java index 9f555fc..e261158 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java @@ -18,16 +18,10 @@ package com.android.textclassifier.downloader; import static com.google.common.truth.Truth.assertThat; -import android.app.Instrumentation; -import android.app.UiAutomation; import android.util.Log; import android.view.textclassifier.TextClassification; import android.view.textclassifier.TextClassification.Request; -import android.view.textclassifier.TextClassifier; -import androidx.test.platform.app.InstrumentationRegistry; import com.android.textclassifier.testing.ExtServicesTextClassifierRule; -import com.android.textclassifier.testing.TestingLocaleListOverrideRule; -import java.io.IOException; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -48,170 +42,133 @@ public class ModelDownloaderIntegrationTest { private static final String V804_EN_TAG = "en_v804"; private static final String V804_RU_TAG = "ru_v804"; private static final String FACTORY_MODEL_TAG = "*"; - - @Rule - public final TestingLocaleListOverrideRule testingLocaleListOverrideRule = - new TestingLocaleListOverrideRule(); + private static final int ASSERT_MAX_ATTEMPTS = 20; + private static final int ASSERT_SLEEP_BEFORE_RETRY_MS = 1000; @Rule public final ExtServicesTextClassifierRule extServicesTextClassifierRule = new ExtServicesTextClassifierRule(); - private TextClassifier textClassifier; - @Before public void setup() throws Exception { - // Flag overrides below can be overridden by Phenotype sync, which makes this test flaky - runShellCommand("device_config put textclassifier config_updater_model_enabled false"); - runShellCommand("device_config put textclassifier model_download_manager_enabled true"); - runShellCommand("device_config put textclassifier model_download_backoff_delay_in_millis 5"); - - textClassifier = extServicesTextClassifierRule.getTextClassifier(); - startExtservicesProcess(); + extServicesTextClassifierRule.addDeviceConfigOverride("config_updater_model_enabled", "false"); + extServicesTextClassifierRule.addDeviceConfigOverride("model_download_manager_enabled", "true"); + extServicesTextClassifierRule.addDeviceConfigOverride( + "model_download_backoff_delay_in_millis", "5"); + extServicesTextClassifierRule.addDeviceConfigOverride("testing_locale_list_override", "en-US"); + extServicesTextClassifierRule.overrideDeviceConfig(); + + extServicesTextClassifierRule.enableVerboseLogging(); + // Verbose logging only takes effect after restarting ExtServices + extServicesTextClassifierRule.forceStopExtServices(); } @After public void tearDown() throws Exception { - runShellCommand("device_config delete textclassifier manifest_url_annotator_en"); - runShellCommand("device_config delete textclassifier manifest_url_annotator_ru"); - runShellCommand("device_config put textclassifier config_updater_model_enabled true"); - runShellCommand("device_config delete textclassifier multi_language_support_enabled"); - runShellCommand( - "device_config put textclassifier model_download_backoff_delay_in_millis 3600000"); + // This is to reset logging/locale_override for ExtServices. + extServicesTextClassifierRule.forceStopExtServices(); } @Test - public void smokeTest() throws IOException, InterruptedException { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + public void smokeTest() throws Exception { + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG)); + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); } @Test - public void downgradeModel() throws IOException, InterruptedException { + public void downgradeModel() throws Exception { // Download an experimental model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); // Downgrade to an older model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG)); - } + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); } @Test - public void upgradeModel() throws IOException, InterruptedException { + public void upgradeModel() throws Exception { // Download a model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG)); - } + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); // Upgrade to an experimental model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); } @Test - public void clearFlag() throws IOException, InterruptedException { + public void clearFlag() throws Exception { // Download a new model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); // Revert the flag. - { - runShellCommand("device_config delete textclassifier manifest_url_annotator_en"); - // Fallback to use the universal model. - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride("manifest_url_annotator_en", ""); + // Fallback to use the universal model. + assertWithRetries( + () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG)); } @Test - public void modelsForMultipleLanguagesDownloaded() throws IOException, InterruptedException { - runShellCommand("device_config put textclassifier multi_language_support_enabled true"); - testingLocaleListOverrideRule.set("en-US", "ru-RU"); + public void modelsForMultipleLanguagesDownloaded() throws Exception { + extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true"); + extServicesTextClassifierRule.addDeviceConfigOverride( + "testing_locale_list_override", "en-US,ru-RU"); // download en model - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); // download ru model - runShellCommand( - "device_config put textclassifier manifest_url_annotator_ru " - + V804_RU_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_ru", V804_RU_ANNOTATOR_MANIFEST_URL); + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - assertWithRetries(/* maxAttempts= */ 10, /* sleepMs= */ 500, this::verifyActiveRussianModel); + assertWithRetries(this::verifyActiveRussianModel); assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, () -> verifyActiveModel(/* text= */ "français", /* expectedVersion= */ FACTORY_MODEL_TAG)); } - private void assertWithRetries(int maxAttempts, int sleepMs, Runnable assertRunnable) - throws InterruptedException { - for (int i = 0; i < maxAttempts; i++) { + private void assertWithRetries(Runnable assertRunnable) throws Exception { + for (int i = 0; i < ASSERT_MAX_ATTEMPTS; i++) { try { + extServicesTextClassifierRule.overrideDeviceConfig(); assertRunnable.run(); break; // success. Bail out. } catch (AssertionError ex) { - if (i == maxAttempts - 1) { // last attempt, give up. + if (i == ASSERT_MAX_ATTEMPTS - 1) { // last attempt, give up. + extServicesTextClassifierRule.dumpDefaultTextClassifierService(); throw ex; } else { - Thread.sleep(sleepMs); + Thread.sleep(ASSERT_SLEEP_BEFORE_RETRY_MS); } + } catch (Exception unknownException) { + throw unknownException; } } } private void verifyActiveModel(String text, String expectedVersion) { TextClassification textClassification = - textClassifier.classifyText(new Request.Builder(text, 0, text.length()).build()); + extServicesTextClassifierRule + .getTextClassifier() + .classifyText(new Request.Builder(text, 0, text.length()).build()); // The result id contains the name of the just used model. + Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId()); assertThat(textClassification.getId()).contains(expectedVersion); } @@ -222,16 +179,4 @@ public class ModelDownloaderIntegrationTest { private void verifyActiveRussianModel() { verifyActiveModel("привет", V804_RU_TAG); } - - private void startExtservicesProcess() { - // Start the process of ExtServices by sending it a text classifier request. - textClassifier.classifyText(new TextClassification.Request.Builder("abc", 0, 3).build()); - } - - private static void runShellCommand(String cmd) { - Log.v(TAG, "run shell command: " + cmd); - Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation(); - UiAutomation uiAutomation = instrumentation.getUiAutomation(); - uiAutomation.executeShellCommand(cmd); - } } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java index eac2af3..76d04e0 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java @@ -37,14 +37,19 @@ import com.google.common.util.concurrent.SettableFuture; import java.io.File; import java.net.URI; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @RunWith(JUnit4.class) public final class ModelDownloaderServiceImplTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final long BYTES_WRITTEN = 1L; private static final String DOWNLOAD_URI = "https://www.gstatic.com/android/text_classifier/r/v999/en.fb"; @@ -66,7 +71,6 @@ public final class ModelDownloaderServiceImplTest { @Before public void setUp() { - MockitoAnnotations.initMocks(this); this.targetModelFile = new File(ApplicationProvider.getApplicationContext().getCacheDir(), "model.fb"); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java index 3ceb47b..5f8247d 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java @@ -20,64 +20,72 @@ import android.app.UiAutomation; import android.content.pm.PackageManager; import android.content.pm.PackageManager.NameNotFoundException; import android.provider.DeviceConfig; +import android.util.Log; import android.view.textclassifier.TextClassificationManager; import android.view.textclassifier.TextClassifier; import androidx.test.core.app.ApplicationProvider; import androidx.test.platform.app.InstrumentationRegistry; +import com.google.common.io.ByteStreams; +import java.io.FileInputStream; +import java.io.IOException; import org.junit.rules.ExternalResource; /** A rule that manages a text classifier that is backed by the ExtServices. */ public final class ExtServicesTextClassifierRule extends ExternalResource { + private static final String TAG = "androidtc"; private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE = "textclassifier_service_package_override"; private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services"; private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services"; - private String textClassifierServiceOverrideFlagOldValue; + private UiAutomation uiAutomation; + private DeviceConfig.Properties originalProperties; + private DeviceConfig.Properties.Builder newPropertiesBuilder; @Override - protected void before() { - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); - try { - uiAutomation.adoptShellPermissionIdentity(); - textClassifierServiceOverrideFlagOldValue = - DeviceConfig.getString( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - null); - DeviceConfig.setProperty( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - getExtServicesPackageName(), - /* makeDefault= */ false); - } finally { - uiAutomation.dropShellPermissionIdentity(); - } + protected void before() throws Exception { + uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); + uiAutomation.adoptShellPermissionIdentity(); + originalProperties = DeviceConfig.getProperties(DeviceConfig.NAMESPACE_TEXTCLASSIFIER); + newPropertiesBuilder = + new DeviceConfig.Properties.Builder(DeviceConfig.NAMESPACE_TEXTCLASSIFIER) + .setString( + CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, getExtServicesPackageName()); + overrideDeviceConfig(); } @Override protected void after() { - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); try { - uiAutomation.adoptShellPermissionIdentity(); - DeviceConfig.setProperty( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - textClassifierServiceOverrideFlagOldValue, - /* makeDefault= */ false); + DeviceConfig.setProperties(originalProperties); + } catch (Throwable t) { + Log.e(TAG, "Failed to reset DeviceConfig", t); } finally { uiAutomation.dropShellPermissionIdentity(); } } - private static String getExtServicesPackageName() { - PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager(); - try { - packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0); - return PKG_NAME_GOOGLE_EXTSERVICES; - } catch (NameNotFoundException e) { - return PKG_NAME_AOSP_EXTSERVICES; - } + public void addDeviceConfigOverride(String name, String value) { + newPropertiesBuilder.setString(name, value); + } + + /** + * Overrides the TextClassifier DeviceConfig manually. + * + * <p>This will clean up all device configs not in newPropertiesBuilder. + * + * <p>We will need to call this everytime before testing, because DeviceConfig can be synced in + * background at anytime. DeviceConfig#setSyncDisabledMode is to disable sync, however it's a + * hidden API. + */ + public void overrideDeviceConfig() throws Exception { + DeviceConfig.setProperties(newPropertiesBuilder.build()); + } + + /** Force stop ExtServices. Force-stop-and-start can be helpful to reload some states. */ + public void forceStopExtServices() { + runShellCommand("am force-stop com.google.android.ext.services"); + runShellCommand("am force-stop android.ext.services"); } public TextClassifier getTextClassifier() { @@ -87,4 +95,38 @@ public final class ExtServicesTextClassifierRule extends ExternalResource { textClassificationManager.setTextClassifier(null); // Reset TC overrides return textClassificationManager.getTextClassifier(); } + + public void dumpDefaultTextClassifierService() { + runShellCommand( + "dumpsys activity service com.google.android.ext.services/" + + "com.android.textclassifier.DefaultTextClassifierService"); + runShellCommand("cmd device_config list textclassifier"); + } + + public void enableVerboseLogging() { + runShellCommand("setprop log.tag.androidtc VERBOSE"); + } + + private void runShellCommand(String cmd) { + Log.v(TAG, "run shell command: " + cmd); + try (FileInputStream output = + new FileInputStream(uiAutomation.executeShellCommand(cmd).getFileDescriptor())) { + String cmdOutput = new String(ByteStreams.toByteArray(output)); + if (!cmdOutput.isEmpty()) { + Log.d(TAG, "cmd output: " + cmdOutput); + } + } catch (IOException ioe) { + Log.w(TAG, "failed to get cmd output", ioe); + } + } + + private static String getExtServicesPackageName() { + PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager(); + try { + packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0); + return PKG_NAME_GOOGLE_EXTSERVICES; + } catch (NameNotFoundException e) { + return PKG_NAME_AOSP_EXTSERVICES; + } + } } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java deleted file mode 100644 index 7d46e97..0000000 --- a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (C) 2018 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.textclassifier.testing; - -import android.app.UiAutomation; -import android.os.LocaleList; -import android.util.Log; -import androidx.test.platform.app.InstrumentationRegistry; -import org.junit.rules.ExternalResource; - -/** class for overriding testing_locale_list_override from {@link TextClassifierSettings} */ -public final class TestingLocaleListOverrideRule extends ExternalResource { - private static final String TAG = "TestingLocaleListOverrideRule"; - - private LocaleList originalLocaleList; - - @Override - protected void before() { - originalLocaleList = LocaleList.getDefault(); - } - - public void set(String... localeTags) { - if (localeTags.length == 0) { - return; - } - runShellCommand( - "device_config put textclassifier testing_locale_list_override " - + String.join(",", localeTags)); - } - - @Override - protected void after() { - runShellCommand( - "device_config put textclassifier testing_locale_list_override " - + originalLocaleList.toLanguageTags()); - runShellCommand("device_config delete textclassifier testing_locale_list_override"); - } - - private static void runShellCommand(String cmd) { - Log.v(TAG, "run shell command: " + cmd); - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); - uiAutomation.executeShellCommand(cmd); - } -} diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs Binary files differindex 7421579..6ebf1cf 100644 --- a/native/actions/actions-entity-data.bfbs +++ b/native/actions/actions-entity-data.bfbs diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs index 21584b6..e906f93 100644 --- a/native/actions/actions-entity-data.fbs +++ b/native/actions/actions-entity-data.fbs @@ -18,7 +18,7 @@ namespace libtextclassifier3; table ActionsEntityData { // Extracted text. - text:string (shared); + text:string (key, shared); } root_type libtextclassifier3.ActionsEntityData; diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc index b1a042c..9f9a8d4 100644 --- a/native/actions/actions-suggestions.cc +++ b/native/actions/actions-suggestions.cc @@ -17,6 +17,7 @@ #include "actions/actions-suggestions.h" #include <memory> +#include <string> #include <vector> #include "utils/base/statusor.h" @@ -40,6 +41,7 @@ #include "utils/strings/stringpiece.h" #include "utils/strings/utf8.h" #include "utils/utf8/unicodetext.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/lite/string_util.h" namespace libtextclassifier3 { @@ -809,12 +811,14 @@ bool ActionsSuggestions::SetupModelInput( void ActionsSuggestions::PopulateTextReplies( const tflite::Interpreter* interpreter, int suggestion_index, - int score_index, const std::string& type, + int score_index, const std::string& type, float priority_score, + const absl::flat_hash_set<std::string>& blocklist, ActionsSuggestionsResponse* response) const { const std::vector<tflite::StringRef> replies = model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter); const TensorView<float> scores = model_executor_->OutputView<float>(score_index, interpreter); + for (int i = 0; i < replies.size(); i++) { if (replies[i].len == 0) { continue; @@ -823,8 +827,12 @@ void ActionsSuggestions::PopulateTextReplies( if (score < preconditions_.min_reply_score_threshold) { continue; } - response->actions.push_back( - {std::string(replies[i].str, replies[i].len), type, score}); + std::string response_text(replies[i].str, replies[i].len); + if (blocklist.contains(response_text)) { + continue; + } + + response->actions.push_back({response_text, type, score, priority_score}); } } @@ -909,10 +917,12 @@ bool ActionsSuggestions::ReadModelOutput( // Read smart reply predictions. if (!response->output_filtered_min_triggering_score && model_->tflite_model_spec()->output_replies() >= 0) { + absl::flat_hash_set<std::string> empty_blocklist; PopulateTextReplies(interpreter, model_->tflite_model_spec()->output_replies(), model_->tflite_model_spec()->output_replies_scores(), - model_->smart_reply_action_type()->str(), response); + model_->smart_reply_action_type()->str(), + /* priority_score */ 0.0, empty_blocklist, response); } // Read actions suggestions. @@ -950,17 +960,26 @@ bool ActionsSuggestions::ReadModelOutput( const int suggestions_index = metadata->output_suggestions(); const int suggestions_scores_index = metadata->output_suggestions_scores(); + absl::flat_hash_set<std::string> response_text_blocklist; switch (metadata->prediction_type()) { case PredictionType_NEXT_MESSAGE_PREDICTION: if (!task_spec || task_spec->type()->size() == 0) { TC3_LOG(WARNING) << "Task type not provided, use default " "smart_reply_action_type!"; } + if (task_spec) { + if (task_spec->response_text_blocklist()) { + for (const auto& val : *task_spec->response_text_blocklist()) { + response_text_blocklist.insert(val->str()); + } + } + } PopulateTextReplies( interpreter, suggestions_index, suggestions_scores_index, task_spec ? task_spec->type()->str() : model_->smart_reply_action_type()->str(), - response); + task_spec ? task_spec->priority_score() : 0.0, + response_text_blocklist, response); break; case PredictionType_INTENT_TRIGGERING: PopulateIntentTriggering(interpreter, suggestions_index, diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h index 32edc78..87f55fb 100644 --- a/native/actions/actions-suggestions.h +++ b/native/actions/actions-suggestions.h @@ -43,6 +43,7 @@ #include "utils/utf8/unilib.h" #include "utils/variant.h" #include "utils/zlib/zlib.h" +#include "absl/container/flat_hash_set.h" namespace libtextclassifier3 { @@ -176,7 +177,8 @@ class ActionsSuggestions { void PopulateTextReplies(const tflite::Interpreter* interpreter, int suggestion_index, int score_index, - const std::string& type, + const std::string& type, float priority_score, + const absl::flat_hash_set<std::string>& blocklist, ActionsSuggestionsResponse* response) const; void PopulateIntentTriggering(const tflite::Interpreter* interpreter, diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc index 062d527..b51ebc7 100644 --- a/native/actions/actions-suggestions_test.cc +++ b/native/actions/actions-suggestions_test.cc @@ -1798,6 +1798,7 @@ TEST_F(ActionsSuggestionsTest, TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) { std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(kMultiTaskSrEmojiModelFileName); + const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "hello?", @@ -1807,9 +1808,31 @@ TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) { /*locales=*/"en"}}}); EXPECT_EQ(response.actions.size(), 5); EXPECT_EQ(response.actions[0].response_text, "😁"); - EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT"); - EXPECT_EQ(response.actions[1].response_text, "Yes"); - EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION"); + EXPECT_EQ(response.actions[0].type, "text_reply"); + EXPECT_EQ(response.actions[1].response_text, "👋"); + EXPECT_EQ(response.actions[1].type, "text_reply"); + EXPECT_EQ(response.actions[2].response_text, "Yes"); + EXPECT_EQ(response.actions[2].type, "text_reply"); +} + +TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) { + std::unique_ptr<ActionsSuggestions> actions_suggestions = + LoadTestModel(kMultiTaskSrEmojiModelFileName); + + const ActionsSuggestionsResponse response = + actions_suggestions->SuggestActions( + {{{/*user_id=*/1, "a pleasure chatting", + /*reference_time_ms_utc=*/0, + /*reference_timezone=*/"Europe/Zurich", + /*annotations=*/{}, + /*locales=*/"en"}}}); + EXPECT_EQ(response.actions.size(), 3); + EXPECT_EQ(response.actions[0].response_text, "😁"); + EXPECT_EQ(response.actions[0].type, "text_reply"); + EXPECT_EQ(response.actions[1].response_text, "😘"); + EXPECT_EQ(response.actions[1].type, "text_reply"); + EXPECT_EQ(response.actions[2].response_text, "Okay"); + EXPECT_EQ(response.actions[2].type, "text_reply"); } TEST_F(ActionsSuggestionsTest, LiveRelayModel) { diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs index 8c03eeb..0d8c7ad 100644 --- a/native/actions/actions_model.fbs +++ b/native/actions/actions_model.fbs @@ -36,6 +36,17 @@ enum PredictionType : int { ENTITY_ANNOTATION = 3, } +namespace libtextclassifier3; +enum RankingOptionsSortType : int { + SORT_TYPE_UNSPECIFIED = 0, + + // Rank results (or groups) by score, then type + SORT_TYPE_SCORE = 1, + + // Rank results (or groups) by priority score, then score, then type + SORT_TYPE_PRIORITY_SCORE = 2, +} + // Prediction metadata for an arbitrary task. namespace libtextclassifier3; table PredictionMetadata { @@ -315,10 +326,11 @@ table ActionSuggestionSpec { // Additional entity information. serialized_entity_data:string (shared); - // Priority score used for internal conflict resolution. + // For ranking and internal conflict resolution. priority_score:float = 0; entity_data:ActionsEntityData; + response_text_blocklist:[string]; } // Options to specify triggering behaviour per action class. @@ -416,6 +428,8 @@ table RankingOptions { // If true, keep actions from the same entities together for ranking. group_by_annotations:bool = true; + + sort_type:RankingOptionsSortType = SORT_TYPE_SCORE; } // Entity data to set from capturing groups. diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc index d52ecaa..46e392a 100644 --- a/native/actions/ranker.cc +++ b/native/actions/ranker.cc @@ -20,6 +20,8 @@ #include <set> #include <vector> +#include "actions/actions_model_generated.h" + #if !defined(TC3_DISABLE_LUA) #include "actions/lua-ranker.h" #endif @@ -34,11 +36,22 @@ namespace libtextclassifier3 { namespace { void SortByScoreAndType(std::vector<ActionSuggestion>* actions) { - std::sort(actions->begin(), actions->end(), - [](const ActionSuggestion& a, const ActionSuggestion& b) { - return a.score > b.score || - (a.score >= b.score && a.type < b.type); - }); + std::stable_sort(actions->begin(), actions->end(), + [](const ActionSuggestion& a, const ActionSuggestion& b) { + return a.score > b.score || + (a.score >= b.score && a.type < b.type); + }); +} + +void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) { + std::stable_sort( + actions->begin(), actions->end(), + [](const ActionSuggestion& a, const ActionSuggestion& b) { + return a.priority_score > b.priority_score || + (a.priority_score >= b.priority_score && a.score > b.score) || + (a.priority_score >= b.priority_score && a.score >= b.score && + a.type < b.type); + }); } template <typename T> @@ -241,13 +254,8 @@ bool ActionsSuggestionsRanker::RankActions( const reflection::Schema* annotations_entity_data_schema) const { if (options_->deduplicate_suggestions() || options_->deduplicate_suggestions_by_span()) { - // First order suggestions by priority score for deduplication. - std::sort( - response->actions.begin(), response->actions.end(), - [](const ActionSuggestion& a, const ActionSuggestion& b) { - return a.priority_score > b.priority_score || - (a.priority_score >= b.priority_score && a.score > b.score); - }); + // Order suggestions by [priority score -> score] for deduplication + SortByPriorityAndScoreAndType(&response->actions); // Deduplicate, keeping the higher score actions. if (options_->deduplicate_suggestions()) { @@ -275,6 +283,8 @@ bool ActionsSuggestionsRanker::RankActions( } } + bool sort_by_priority = + options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE; // Suppress smart replies if actions are present. if (options_->suppress_smart_replies_with_actions()) { std::vector<ActionSuggestion> non_smart_reply_actions; @@ -316,17 +326,35 @@ bool ActionsSuggestionsRanker::RankActions( // Sort within each group by score. for (std::vector<ActionSuggestion>& group : groups) { - SortByScoreAndType(&group); + if (sort_by_priority) { + SortByPriorityAndScoreAndType(&group); + } else { + SortByScoreAndType(&group); + } } - // Sort groups by maximum score. - std::sort(groups.begin(), groups.end(), - [](const std::vector<ActionSuggestion>& a, - const std::vector<ActionSuggestion>& b) { - return a.begin()->score > b.begin()->score || - (a.begin()->score >= b.begin()->score && - a.begin()->type < b.begin()->type); - }); + // Sort groups by maximum score or priority score. + if (sort_by_priority) { + std::stable_sort( + groups.begin(), groups.end(), + [](const std::vector<ActionSuggestion>& a, + const std::vector<ActionSuggestion>& b) { + return (a.begin()->priority_score > b.begin()->priority_score) || + (a.begin()->priority_score >= b.begin()->priority_score && + a.begin()->score > b.begin()->score) || + (a.begin()->priority_score >= b.begin()->priority_score && + a.begin()->score >= b.begin()->score && + a.begin()->type < b.begin()->type); + }); + } else { + std::stable_sort(groups.begin(), groups.end(), + [](const std::vector<ActionSuggestion>& a, + const std::vector<ActionSuggestion>& b) { + return a.begin()->score > b.begin()->score || + (a.begin()->score >= b.begin()->score && + a.begin()->type < b.begin()->type); + }); + } // Flatten result. const size_t num_actions = response->actions.size(); @@ -336,9 +364,9 @@ bool ActionsSuggestionsRanker::RankActions( response->actions.insert(response->actions.end(), actions.begin(), actions.end()); } - + } else if (sort_by_priority) { + SortByPriorityAndScoreAndType(&response->actions); } else { - // Order suggestions independently by score. SortByScoreAndType(&response->actions); } diff --git a/native/actions/ranker_test.cc b/native/actions/ranker_test.cc index b52cf45..5eba45f 100644 --- a/native/actions/ranker_test.cc +++ b/native/actions/ranker_test.cc @@ -18,6 +18,7 @@ #include <string> +#include "actions/actions_model_generated.h" #include "actions/types.h" #include "utils/zlib/zlib.h" #include "gmock/gmock.h" @@ -308,12 +309,12 @@ TEST(RankingTest, GroupsActionsByAnnotations) { response.actions.push_back({/*response_text=*/"", /*type=*/"call_phone", /*score=*/1.0, - /*priority_score=*/1.0, + /*priority_score=*/0.0, /*annotations=*/{annotation}}); response.actions.push_back({/*response_text=*/"", /*type=*/"add_contact", /*score=*/0.0, - /*priority_score=*/0.0, + /*priority_score=*/1.0, /*annotations=*/{annotation}}); } response.actions.push_back({/*response_text=*/"How are you?", @@ -338,23 +339,75 @@ TEST(RankingTest, GroupsActionsByAnnotations) { IsAction("text_reply", "How are you?", 0.5)})); } -TEST(RankingTest, SortsActionsByScore) { +TEST(RankingTest, GroupsByAnnotationsSortedByPriority) { const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}}; ActionsSuggestionsResponse response; + response.actions.push_back({/*response_text=*/"How are you?", + /*type=*/"text_reply", + /*score=*/2.0, + /*priority_score=*/0.0}); { ActionSuggestionAnnotation annotation; annotation.span = {/*message_index=*/0, /*span=*/{5, 8}, /*text=*/"911"}; annotation.entity = ClassificationResult("phone", 1.0); response.actions.push_back({/*response_text=*/"", + /*type=*/"add_contact", + /*score=*/0.0, + /*priority_score=*/1.0, + /*annotations=*/{annotation}}); + response.actions.push_back({/*response_text=*/"", /*type=*/"call_phone", /*score=*/1.0, + /*priority_score=*/0.0, + /*annotations=*/{annotation}}); + response.actions.push_back({/*response_text=*/"", + /*type=*/"add_contact2", + /*score=*/0.5, /*priority_score=*/1.0, /*annotations=*/{annotation}}); + } + RankingOptionsT options; + options.group_by_annotations = true; + options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(RankingOptions::Pack(builder, &options)); + auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker( + flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()), + /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply"); + + ranker->RankActions(conversation, &response); + + // The text reply should be last, even though it's score is higher than + // any other scores -- because it's priority_score is lower than the max + // of those with the 'phone' annotation + EXPECT_THAT(response.actions, + testing::ElementsAreArray({ + // Group 1 (Phone annotation) + IsAction("add_contact2", "", 0.5), // priority_score=1.0 + IsAction("add_contact", "", 0.0), // priority_score=1.0 + IsAction("call_phone", "", 1.0), // priority_score=0.0 + IsAction("text_reply", "How are you?", 2.0), // Group 2 + })); +} + +TEST(RankingTest, SortsActionsByScore) { + const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}}; + ActionsSuggestionsResponse response; + { + ActionSuggestionAnnotation annotation; + annotation.span = {/*message_index=*/0, /*span=*/{5, 8}, + /*text=*/"911"}; + annotation.entity = ClassificationResult("phone", 1.0); + response.actions.push_back({/*response_text=*/"", + /*type=*/"call_phone", + /*score=*/1.0, + /*priority_score=*/0.0, + /*annotations=*/{annotation}}); response.actions.push_back({/*response_text=*/"", /*type=*/"add_contact", /*score=*/0.0, - /*priority_score=*/0.0, + /*priority_score=*/1.0, /*annotations=*/{annotation}}); } response.actions.push_back({/*response_text=*/"How are you?", @@ -378,5 +431,40 @@ TEST(RankingTest, SortsActionsByScore) { IsAction("add_contact", "", 0.0)})); } +TEST(RankingTest, SortsActionsByPriority) { + const Conversation conversation = {{{/*user_id=*/1, "hello?"}}}; + ActionsSuggestionsResponse response; + // emoji replies given higher priority_score + response.actions.push_back({/*response_text=*/"😁", + /*type=*/"text_reply", + /*score=*/0.5, + /*priority_score=*/1.0}); + response.actions.push_back({/*response_text=*/"👋", + /*type=*/"text_reply", + /*score=*/0.4, + /*priority_score=*/1.0}); + response.actions.push_back({/*response_text=*/"Yes", + /*type=*/"text_reply", + /*score=*/1.0, + /*priority_score=*/0.0}); + RankingOptionsT options; + // Don't group by annotation. + options.group_by_annotations = false; + options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(RankingOptions::Pack(builder, &options)); + auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker( + flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()), + /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply"); + + ranker->RankActions(conversation, &response); + + EXPECT_THAT(response.actions, testing::ElementsAreArray( + {IsAction("text_reply", "😁", 0.5), + IsAction("text_reply", "👋", 0.4), + // Ranked last because of priority score + IsAction("text_reply", "Yes", 1.0)})); +} + } // namespace } // namespace libtextclassifier3 diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model Binary files differindex 77e556c..0fa7f7e 100644 --- a/native/actions/test_data/actions_suggestions_grammar_test.model +++ b/native/actions/test_data/actions_suggestions_grammar_test.model diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model Binary files differindex c468bd5..6107e98 100644 --- a/native/actions/test_data/actions_suggestions_test.model +++ b/native/actions/test_data/actions_suggestions_test.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model Binary files differindex ec421a1..436ed93 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model Binary files differindex 24be6c6..935691d 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model Binary files differindex fd7ddf2..2c9f74b 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model Binary files differindex c969c56..cdb7523 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model Binary files differindex d171898..ac28fa2 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model Binary files differindex 937552b..d864b79 100644 --- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model +++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc index 32bd29c..e0d4241 100644 --- a/native/annotator/annotator.cc +++ b/native/annotator/annotator.cc @@ -973,11 +973,11 @@ CodepointSpan Annotator::SuggestSelection( // Sort candidates according to their position in the input, so that the next // code can assume that any connected component of overlapping spans forms a // contiguous block. - std::sort(candidates.annotated_spans[0].begin(), - candidates.annotated_spans[0].end(), - [](const AnnotatedSpan& a, const AnnotatedSpan& b) { - return a.span.first < b.span.first; - }); + std::stable_sort(candidates.annotated_spans[0].begin(), + candidates.annotated_spans[0].end(), + [](const AnnotatedSpan& a, const AnnotatedSpan& b) { + return a.span.first < b.span.first; + }); std::vector<int> candidate_indices; if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens, @@ -987,13 +987,14 @@ CodepointSpan Annotator::SuggestSelection( return original_click_indices; } - std::sort(candidate_indices.begin(), candidate_indices.end(), - [this, &candidates](int a, int b) { - return GetPriorityScore( - candidates.annotated_spans[0][a].classification) > - GetPriorityScore( - candidates.annotated_spans[0][b].classification); - }); + std::stable_sort( + candidate_indices.begin(), candidate_indices.end(), + [this, &candidates](int a, int b) { + return GetPriorityScore( + candidates.annotated_spans[0][a].classification) > + GetPriorityScore( + candidates.annotated_spans[0][b].classification); + }); for (const int i : candidate_indices) { if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) && @@ -1173,7 +1174,7 @@ bool Annotator::ResolveConflict( } } - std::sort( + std::stable_sort( conflicting_indices.begin(), conflicting_indices.end(), [this, &scores_lengths, candidates, conflicting_indices](int i, int j) { if (scores_lengths[i].first == scores_lengths[j].first && @@ -1241,7 +1242,7 @@ bool Annotator::ResolveConflict( chosen_indices_for_source_ptr->insert(considered_candidate); } - std::sort(chosen_indices->begin(), chosen_indices->end()); + std::stable_sort(chosen_indices->begin(), chosen_indices->end()); return true; } @@ -1414,10 +1415,11 @@ namespace { // Sorts the classification results from high score to low score. void SortClassificationResults( std::vector<ClassificationResult>* classification_results) { - std::sort(classification_results->begin(), classification_results->end(), - [](const ClassificationResult& a, const ClassificationResult& b) { - return a.score > b.score; - }); + std::stable_sort( + classification_results->begin(), classification_results->end(), + [](const ClassificationResult& a, const ClassificationResult& b) { + return a.score > b.score; + }); } } // namespace @@ -1936,10 +1938,11 @@ std::vector<ClassificationResult> Annotator::ClassifyText( } // Sort results according to score. - std::sort(results.begin(), results.end(), - [](const ClassificationResult& a, const ClassificationResult& b) { - return a.score > b.score; - }); + std::stable_sort( + results.begin(), results.end(), + [](const ClassificationResult& a, const ClassificationResult& b) { + return a.score > b.score; + }); if (results.empty()) { results = {{Collections::Other(), 1.0}}; @@ -2297,19 +2300,19 @@ Status Annotator::AnnotateSingleInput( // Also sort them according to the end position and collection, so that the // deduplication code below can assume that same spans and classifications // form contiguous blocks. - std::sort(candidates->begin(), candidates->end(), - [](const AnnotatedSpan& a, const AnnotatedSpan& b) { - if (a.span.first != b.span.first) { - return a.span.first < b.span.first; - } + std::stable_sort(candidates->begin(), candidates->end(), + [](const AnnotatedSpan& a, const AnnotatedSpan& b) { + if (a.span.first != b.span.first) { + return a.span.first < b.span.first; + } - if (a.span.second != b.span.second) { - return a.span.second < b.span.second; - } + if (a.span.second != b.span.second) { + return a.span.second < b.span.second; + } - return a.classification[0].collection < - b.classification[0].collection; - }); + return a.classification[0].collection < + b.classification[0].collection; + }); std::vector<int> candidate_indices; if (!ResolveConflicts(*candidates, context, tokens, @@ -2904,10 +2907,10 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest, return false; } } - std::sort(scored_chunks.rbegin(), scored_chunks.rend(), - [](const ScoredChunk& lhs, const ScoredChunk& rhs) { - return lhs.score < rhs.score; - }); + std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(), + [](const ScoredChunk& lhs, const ScoredChunk& rhs) { + return lhs.score < rhs.score; + }); // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick // them greedily as long as they do not overlap with any previously picked @@ -2936,7 +2939,7 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest, chunks->push_back(scored_chunk.token_span); } - std::sort(chunks->begin(), chunks->end()); + std::stable_sort(chunks->begin(), chunks->end()); return true; } diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc index 7d5f440..ff0c775 100644 --- a/native/annotator/datetime/datetime-grounder.cc +++ b/native/annotator/datetime/datetime-grounder.cc @@ -16,6 +16,7 @@ #include "annotator/datetime/datetime-grounder.h" +#include <algorithm> #include <limits> #include <unordered_map> #include <vector> @@ -250,10 +251,10 @@ StatusOr<std::vector<DatetimeParseResult>> DatetimeGrounder::Ground( } // Sort the date time units by component type. - std::sort(date_components.begin(), date_components.end(), - [](DatetimeComponent a, DatetimeComponent b) { - return a.component_type > b.component_type; - }); + std::stable_sort(date_components.begin(), date_components.end(), + [](DatetimeComponent a, DatetimeComponent b) { + return a.component_type > b.component_type; + }); result.datetime_components.swap(date_components); datetime_parse_result.push_back(result); } diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc index 867c886..94a0961 100644 --- a/native/annotator/datetime/extractor.cc +++ b/native/annotator/datetime/extractor.cc @@ -16,6 +16,8 @@ #include "annotator/datetime/extractor.h" +#include <algorithm> + #include "annotator/datetime/utils.h" #include "annotator/model_generated.h" #include "annotator/types.h" @@ -347,10 +349,11 @@ bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input, } } - std::sort(found_numbers.begin(), found_numbers.end(), - [](const std::pair<int, int>& a, const std::pair<int, int>& b) { - return a.first < b.first; - }); + std::stable_sort( + found_numbers.begin(), found_numbers.end(), + [](const std::pair<int, int>& a, const std::pair<int, int>& b) { + return a.first < b.first; + }); int sum = 0; int running_value = -1; diff --git a/native/annotator/datetime/regex-parser.cc b/native/annotator/datetime/regex-parser.cc index 4dc9c56..5daabd5 100644 --- a/native/annotator/datetime/regex-parser.cc +++ b/native/annotator/datetime/regex-parser.cc @@ -16,6 +16,7 @@ #include "annotator/datetime/regex-parser.h" +#include <algorithm> #include <iterator> #include <set> #include <unordered_set> @@ -191,17 +192,17 @@ StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse( // Resolve conflicts by always picking the longer span and breaking ties by // selecting the earlier entry in the list for a given locale. - std::sort(indexed_found_spans.begin(), indexed_found_spans.end(), - [](const std::pair<DatetimeParseResultSpan, int>& a, - const std::pair<DatetimeParseResultSpan, int>& b) { - if ((a.first.span.second - a.first.span.first) != - (b.first.span.second - b.first.span.first)) { - return (a.first.span.second - a.first.span.first) > - (b.first.span.second - b.first.span.first); - } else { - return a.second < b.second; - } - }); + std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(), + [](const std::pair<DatetimeParseResultSpan, int>& a, + const std::pair<DatetimeParseResultSpan, int>& b) { + if ((a.first.span.second - a.first.span.first) != + (b.first.span.second - b.first.span.first)) { + return (a.first.span.second - a.first.span.first) > + (b.first.span.second - b.first.span.first); + } else { + return a.second < b.second; + } + }); std::vector<DatetimeParseResultSpan> results; std::vector<DatetimeParseResultSpan> resolved_found_spans; @@ -394,10 +395,10 @@ bool RegexDatetimeParser::ExtractDatetime( } // Sort the date time units by component type. - std::sort(date_components.begin(), date_components.end(), - [](DatetimeComponent a, DatetimeComponent b) { - return a.component_type > b.component_type; - }); + std::stable_sort(date_components.begin(), date_components.end(), + [](DatetimeComponent a, DatetimeComponent b) { + return a.component_type > b.component_type; + }); result.datetime_components.swap(date_components); results->push_back(result); } diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc index 640ceec..2c5a43c 100644 --- a/native/annotator/translate/translate.cc +++ b/native/annotator/translate/translate.cc @@ -16,6 +16,7 @@ #include "annotator/translate/translate.h" +#include <algorithm> #include <memory> #include "annotator/collections.h" @@ -142,11 +143,11 @@ TranslateAnnotator::BackoffDetectLanguages( result.push_back({key, value}); } - std::sort(result.begin(), result.end(), - [](TranslateAnnotator::LanguageConfidence& a, - TranslateAnnotator::LanguageConfidence& b) { - return a.confidence > b.confidence; - }); + std::stable_sort(result.begin(), result.end(), + [](const TranslateAnnotator::LanguageConfidence& a, + const TranslateAnnotator::LanguageConfidence& b) { + return a.confidence > b.confidence; + }); return result; } diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc index 469cb1f..49c9ca0 100644 --- a/native/lang_id/common/embedding-network.cc +++ b/native/lang_id/common/embedding-network.cc @@ -16,6 +16,8 @@ #include "lang_id/common/embedding-network.h" +#include <vector> + #include "lang_id/common/lite_base/integral-types.h" #include "lang_id/common/lite_base/logging.h" diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc index ab8a1a6..4e304fe 100644 --- a/native/lang_id/common/fel/feature-extractor.cc +++ b/native/lang_id/common/fel/feature-extractor.cc @@ -17,6 +17,7 @@ #include "lang_id/common/fel/feature-extractor.h" #include <string> +#include <vector> #include "lang_id/common/fel/feature-types.h" #include "lang_id/common/fel/fel-parser.h" diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc index af41e29..60dcc46 100644 --- a/native/lang_id/common/fel/workspace.cc +++ b/native/lang_id/common/fel/workspace.cc @@ -18,6 +18,7 @@ #include <atomic> #include <string> +#include <vector> namespace libtextclassifier3 { namespace mobile { diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h index f13d802..2ac5b26 100644 --- a/native/lang_id/common/fel/workspace.h +++ b/native/lang_id/common/fel/workspace.h @@ -23,6 +23,7 @@ #include <stddef.h> +#include <algorithm> #include <string> #include <unordered_map> #include <utility> diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc index 19afcc4..fc925ea 100644 --- a/native/lang_id/common/file/mmap.cc +++ b/native/lang_id/common/file/mmap.cc @@ -29,6 +29,8 @@ #endif #include <sys/stat.h> +#include <string> + #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/lite_base/macros.h" diff --git a/native/lang_id/common/lite_strings/str-split.cc b/native/lang_id/common/lite_strings/str-split.cc index 199bb69..d227eec 100644 --- a/native/lang_id/common/lite_strings/str-split.cc +++ b/native/lang_id/common/lite_strings/str-split.cc @@ -16,6 +16,8 @@ #include "lang_id/common/lite_strings/str-split.h" +#include <vector> + namespace libtextclassifier3 { namespace mobile { diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc index 750341d..249ed57 100644 --- a/native/lang_id/common/math/softmax.cc +++ b/native/lang_id/common/math/softmax.cc @@ -17,6 +17,7 @@ #include "lang_id/common/math/softmax.h" #include <algorithm> +#include <vector> #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/math/fastexp.h" diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc index dc36fb7..51c8c47 100644 --- a/native/lang_id/fb_model/lang-id-from-fb.cc +++ b/native/lang_id/fb_model/lang-id-from-fb.cc @@ -16,7 +16,9 @@ #include "lang_id/fb_model/lang-id-from-fb.h" +#include <memory> #include <string> +#include <utility> #include "lang_id/fb_model/model-provider-from-fb.h" diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc index 43bf860..d14d403 100644 --- a/native/lang_id/fb_model/model-provider-from-fb.cc +++ b/native/lang_id/fb_model/model-provider-from-fb.cc @@ -16,7 +16,9 @@ #include "lang_id/fb_model/model-provider-from-fb.h" +#include <memory> #include <string> +#include <utility> #include "lang_id/common/file/file-utils.h" #include "lang_id/common/file/mmap.h" diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc index 92359a9..f7c66f7 100644 --- a/native/lang_id/lang-id.cc +++ b/native/lang_id/lang-id.cc @@ -21,6 +21,7 @@ #include <memory> #include <string> #include <unordered_map> +#include <utility> #include <vector> #include "lang_id/common/embedding-feature-interface.h" diff --git a/native/utils/codepoint-range.cc b/native/utils/codepoint-range.cc index e26b160..a4cd485 100644 --- a/native/utils/codepoint-range.cc +++ b/native/utils/codepoint-range.cc @@ -31,10 +31,11 @@ void SortCodepointRanges( CodepointRangeStruct(range->start(), range->end())); } - std::sort(sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(), - [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) { - return a.start < b.start; - }); + std::stable_sort( + sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(), + [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) { + return a.start < b.start; + }); } // Returns true if given codepoint is covered by the given sorted vector of diff --git a/native/utils/grammar/parsing/parser.cc b/native/utils/grammar/parsing/parser.cc index 4e39a98..a9e99ba 100644 --- a/native/utils/grammar/parsing/parser.cc +++ b/native/utils/grammar/parsing/parser.cc @@ -16,6 +16,7 @@ #include "utils/grammar/parsing/parser.h" +#include <algorithm> #include <unordered_map> #include "utils/grammar/parsing/parse-tree.h" @@ -177,14 +178,14 @@ std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input, } } - std::sort(symbols.begin(), symbols.end(), - [](const Symbol& a, const Symbol& b) { - // Sort by increasing (end, start) position to guarantee the - // matcher requirement that the tokens are fed in non-decreasing - // end position order. - return std::tie(a.codepoint_span.second, a.codepoint_span.first) < - std::tie(b.codepoint_span.second, b.codepoint_span.first); - }); + std::stable_sort( + symbols.begin(), symbols.end(), [](const Symbol& a, const Symbol& b) { + // Sort by increasing (end, start) position to guarantee the + // matcher requirement that the tokens are fed in non-decreasing + // end position order. + return std::tie(a.codepoint_span.second, a.codepoint_span.first) < + std::tie(b.codepoint_span.second, b.codepoint_span.first); + }); return symbols; } diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc index dd29e3c..c134550 100644 --- a/native/utils/grammar/utils/ir.cc +++ b/native/utils/grammar/utils/ir.cc @@ -16,6 +16,8 @@ #include "utils/grammar/utils/ir.h" +#include <algorithm> + #include "utils/i18n/locale.h" #include "utils/strings/append.h" #include "utils/strings/stringpiece.h" @@ -28,14 +30,16 @@ constexpr size_t kMaxHashTableSize = 100; template <typename T> void SortForBinarySearchLookup(T* entries) { - std::sort(entries->begin(), entries->end(), - [](const auto& a, const auto& b) { return a->key < b->key; }); + std::stable_sort( + entries->begin(), entries->end(), + [](const auto& a, const auto& b) { return a->key < b->key; }); } template <typename T> void SortStructsForBinarySearchLookup(T* entries) { - std::sort(entries->begin(), entries->end(), - [](const auto& a, const auto& b) { return a.key() < b.key(); }); + std::stable_sort( + entries->begin(), entries->end(), + [](const auto& a, const auto& b) { return a.key() < b.key(); }); } bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) { @@ -76,13 +80,14 @@ bool IsSameLhsSet(const Ir::LhsSet& lhs_set, Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) { Ir::LhsSet sorted_lhs = lhs_set; - std::sort(sorted_lhs.begin(), sorted_lhs.end(), - [](const Ir::Lhs& a, const Ir::Lhs& b) { - return std::tie(a.nonterminal, a.callback.id, a.callback.param, - a.preconditions.max_whitespace_gap) < - std::tie(b.nonterminal, b.callback.id, b.callback.param, - b.preconditions.max_whitespace_gap); - }); + std::stable_sort( + sorted_lhs.begin(), sorted_lhs.end(), + [](const Ir::Lhs& a, const Ir::Lhs& b) { + return std::tie(a.nonterminal, a.callback.id, a.callback.param, + a.preconditions.max_whitespace_gap) < + std::tie(b.nonterminal, b.callback.id, b.callback.param, + b.preconditions.max_whitespace_gap); + }); return lhs_set; } @@ -300,10 +305,10 @@ void Ir::SerializeTerminalRules( TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second}); } } - std::sort(terminal_rules.begin(), terminal_rules.end(), - [](const TerminalEntry& a, const TerminalEntry& b) { - return a.terminal < b.terminal; - }); + std::stable_sort(terminal_rules.begin(), terminal_rules.end(), + [](const TerminalEntry& a, const TerminalEntry& b) { + return a.terminal < b.terminal; + }); // Index the entries in sorted order. std::vector<int> index(terminal_rules_sets.size(), 0); diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc index e6db06d..141ce5d 100644 --- a/native/utils/grammar/utils/locale-shard-map.cc +++ b/native/utils/grammar/utils/locale-shard-map.cc @@ -40,8 +40,8 @@ std::vector<Locale> LocaleTagsToLocaleList(const std::string& locale_tags) { locale_list.emplace_back(locale); } } - std::sort(locale_list.begin(), locale_list.end(), - [](const Locale& a, const Locale& b) { return a < b; }); + std::stable_sort(locale_list.begin(), locale_list.end(), + [](const Locale& a, const Locale& b) { return a < b; }); return locale_list; } diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h index 30c7aed..c23b5dc 100644 --- a/native/utils/testing/test_data_generator.h +++ b/native/utils/testing/test_data_generator.h @@ -20,6 +20,7 @@ #include <algorithm> #include <iostream> #include <random> +#include <string> #include "utils/strings/stringpiece.h" @@ -35,6 +36,18 @@ class TestDataGenerator { return dist(random_engine_); } + template <> + bool generate() { + std::bernoulli_distribution dist(0.5); + return dist(random_engine_); + } + + template <> + char generate() { + std::uniform_int_distribution<int> dist(0, 25); + return dist(random_engine_) + 'a'; + } + template <typename T, typename std::enable_if_t< std::is_floating_point<T>::value>* = nullptr> T generate() { diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc index 463d910..644dde8 100644 --- a/native/utils/tflite-model-executor.cc +++ b/native/utils/tflite-model-executor.cc @@ -27,6 +27,8 @@ namespace builtin { TfLiteRegistration* Register_ADD(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_CONV_2D(); +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Register_AVERAGE_POOL_2D(); TfLiteRegistration* Register_EQUAL(); TfLiteRegistration* Register_FULLY_CONNECTED(); TfLiteRegistration* Register_GREATER_EQUAL(); @@ -89,7 +91,9 @@ TfLiteRegistration* Register_GREATER(); #include "utils/tflite/dist_diversification.h" #include "utils/tflite/string_projection.h" #include "utils/tflite/text_encoder.h" +#include "utils/tflite/text_encoder3s.h" #include "utils/tflite/token_encoder.h" + namespace tflite { namespace ops { namespace custom { @@ -114,6 +118,14 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { tflite::ops::builtin::Register_CONV_2D(), /*min_version=*/1, /*max_version=*/5); + resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(), + /*min_version=*/1, + /*max_version=*/6); + resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D, + tflite::ops::builtin::Register_AVERAGE_POOL_2D(), + /*min_version=*/1, + /*max_version=*/1); resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL, ::tflite::ops::builtin::Register_EQUAL()); @@ -289,6 +301,8 @@ std::unique_ptr<tflite::OpResolver> BuildOpResolver( tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION()); resolver->AddCustom("TextEncoder", tflite::ops::custom::Register_TEXT_ENCODER()); + resolver->AddCustom("TextEncoder3S", + tflite::ops::custom::Register_TEXT_ENCODER3S()); resolver->AddCustom("TokenEncoder", tflite::ops::custom::Register_TOKEN_ENCODER()); resolver->AddCustom( diff --git a/native/utils/tflite/encoder_common.cc b/native/utils/tflite/encoder_common.cc index 8f9f2a8..eb319f9 100644 --- a/native/utils/tflite/encoder_common.cc +++ b/native/utils/tflite/encoder_common.cc @@ -58,6 +58,11 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate( out->data.i32 + output_offset + from_this_element, in.data.i32[value_index]); } break; + case kTfLiteInt64: { + std::fill(out->data.i64 + output_offset, + out->data.i64 + output_offset + from_this_element, + in.data.i64[value_index]); + } break; case kTfLiteFloat32: { std::fill(out->data.f + output_offset, out->data.f + output_offset + from_this_element, @@ -78,6 +83,12 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate( std::fill(out->data.i32 + output_offset, out->data.i32 + output_size, value); } break; + case kTfLiteInt64: { + const int64_t value = + (output_offset > 0) ? out->data.i64[output_offset - 1] : 0; + std::fill(out->data.i64 + output_offset, out->data.i64 + output_size, + value); + } break; case kTfLiteFloat32: { const float value = (output_offset > 0) ? out->data.f[output_offset - 1] : 0; diff --git a/native/utils/tflite/text_encoder3s.cc b/native/utils/tflite/text_encoder3s.cc new file mode 100644 index 0000000..0b5e65b --- /dev/null +++ b/native/utils/tflite/text_encoder3s.cc @@ -0,0 +1,243 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tflite/text_encoder3s.h" + +#include <memory> +#include <vector> + +#include "utils/base/logging.h" +#include "utils/strings/stringpiece.h" +#include "utils/tflite/encoder_common.h" +#include "utils/tflite/text_encoder_config_generated.h" +#include "utils/tokenfree/byte_encoder.h" +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace libtextclassifier3 { +namespace { + +// Input parameters for the op. +constexpr int kInputTextInd = 0; + +constexpr int kTextLengthInd = 1; +constexpr int kMaxLengthInd = 2; +constexpr int kInputAttrInd = 3; + +// Output parameters for the op. +constexpr int kOutputEncodedInd = 0; +constexpr int kOutputPositionInd = 1; +constexpr int kOutputLengthsInd = 2; +constexpr int kOutputAttrInd = 3; + +// Initializes text encoder object from serialized parameters. +void* Initialize(TfLiteContext* context, const char* buffer, size_t length) { + std::unique_ptr<ByteEncoder> encoder(new ByteEncoder()); + return encoder.release(); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<ByteEncoder*>(buffer); +} + +namespace { +TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, + int max_output_length) { + TfLiteTensor& output_encoded = + context->tensors[node->outputs->data[kOutputEncodedInd]]; + + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_encoded, + CreateIntArray({kEncoderBatchSize, max_output_length}))); + TfLiteTensor& output_positions = + context->tensors[node->outputs->data[kOutputPositionInd]]; + + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_positions, + CreateIntArray({kEncoderBatchSize, max_output_length}))); + + const int num_output_attrs = node->outputs->size - kOutputAttrInd; + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteTensor& output = + context->tensors[node->outputs->data[kOutputAttrInd + i]]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output, + CreateIntArray({kEncoderBatchSize, max_output_length}))); + } + return kTfLiteOk; +} +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check that the batch dimension is kEncoderBatchSize. + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[kInputTextInd]]; + TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank); + TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize); + + TfLiteTensor& output_lengths = + context->tensors[node->outputs->data[kOutputLengthsInd]]; + + TfLiteTensor& output_encoded = + context->tensors[node->outputs->data[kOutputEncodedInd]]; + TfLiteTensor& output_positions = + context->tensors[node->outputs->data[kOutputPositionInd]]; + output_encoded.type = kTfLiteInt32; + output_positions.type = kTfLiteInt32; + output_lengths.type = kTfLiteInt32; + + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, &output_lengths, + CreateIntArray({kEncoderBatchSize}))); + + // Check that there are enough outputs for attributes. + const int num_output_attrs = node->outputs->size - kOutputAttrInd; + TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd, + num_output_attrs); + + // Copy attribute types from input to output tensors. + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteTensor& input = + context->tensors[node->inputs->data[kInputAttrInd + i]]; + TfLiteTensor& output = + context->tensors[node->outputs->data[kOutputAttrInd + i]]; + output.type = input.type; + } + + const TfLiteTensor& output_length = + context->tensors[node->inputs->data[kMaxLengthInd]]; + + if (tflite::IsConstantTensor(&output_length)) { + return ResizeOutputTensors(context, node, output_length.data.i64[0]); + } else { + tflite::SetTensorToDynamic(&output_encoded); + tflite::SetTensorToDynamic(&output_positions); + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteTensor& output_attr = + context->tensors[node->outputs->data[kOutputAttrInd + i]]; + tflite::SetTensorToDynamic(&output_attr); + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + if (node->user_data == nullptr) { + return kTfLiteError; + } + auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data); + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[kInputTextInd]]; + const int num_strings_in_tensor = tflite::GetStringCount(&input_text); + const int num_strings = + context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0]; + + // Check that the number of strings is not bigger than the input tensor size. + TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings); + + TfLiteTensor& output_encoded = + context->tensors[node->outputs->data[kOutputEncodedInd]]; + if (tflite::IsDynamicTensor(&output_encoded)) { + const TfLiteTensor& output_length = + context->tensors[node->inputs->data[kMaxLengthInd]]; + TF_LITE_ENSURE_OK( + context, ResizeOutputTensors(context, node, output_length.data.i64[0])); + } + TfLiteTensor& output_positions = + context->tensors[node->outputs->data[kOutputPositionInd]]; + + std::vector<int> encoded_total; + std::vector<int> encoded_positions; + std::vector<int> encoded_offsets; + encoded_offsets.reserve(num_strings); + const int max_output_length = output_encoded.dims->data[1]; + const int max_encoded_position = max_output_length; + + for (int i = 0; i < num_strings; ++i) { + const auto& strref = tflite::GetString(&input_text, i); + std::vector<int64_t> encoded; + text_encoder->Encode( + libtextclassifier3::StringPiece(strref.str, strref.len), &encoded); + encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end()); + encoded_offsets.push_back(encoded_total.size()); + for (int i = 0; i < encoded.size(); ++i) { + encoded_positions.push_back(std::min(i, max_encoded_position - 1)); + } + } + + // Copy encoding to output tensor. + const int start_offset = + std::max(0, static_cast<int>(encoded_total.size()) - max_output_length); + int output_offset = 0; + int32_t* output_buffer = output_encoded.data.i32; + int32_t* output_positions_buffer = output_positions.data.i32; + for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) { + output_buffer[output_offset] = encoded_total[i]; + output_positions_buffer[output_offset] = encoded_positions[i]; + } + + // Save output encoded length. + TfLiteTensor& output_lengths = + context->tensors[node->outputs->data[kOutputLengthsInd]]; + output_lengths.data.i32[0] = output_offset; + + // Do padding. + for (; output_offset < max_output_length; ++output_offset) { + output_buffer[output_offset] = 0; + output_positions_buffer[output_offset] = 0; + } + + // Process attributes, all checks of sizes and types are done in Prepare. + const int num_output_attrs = node->outputs->size - kOutputAttrInd; + TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd, + num_output_attrs); + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate( + context->tensors[node->inputs->data[kInputAttrInd + i]], + encoded_offsets, start_offset, context, + &context->tensors[node->outputs->data[kOutputAttrInd + i]]); + if (attr_status != kTfLiteOk) { + return attr_status; + } + } + + return kTfLiteOk; +} + +} // namespace +} // namespace libtextclassifier3 + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_TEXT_ENCODER3S() { + static TfLiteRegistration registration = { + libtextclassifier3::Initialize, libtextclassifier3::Free, + libtextclassifier3::Prepare, libtextclassifier3::Eval}; + return ®istration; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/native/utils/tflite/text_encoder3s.h b/native/utils/tflite/text_encoder3s.h new file mode 100644 index 0000000..50e1e64 --- /dev/null +++ b/native/utils/tflite/text_encoder3s.h @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// An encoder that produces positional and attributes encodings for a +// transformer style model based on byte segmentation of text. + +#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_ +#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_ + +#include "tensorflow/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_TEXT_ENCODER3S(); + +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_ diff --git a/native/utils/tokenfree/byte_encoder.cc b/native/utils/tokenfree/byte_encoder.cc new file mode 100644 index 0000000..c79d3a2 --- /dev/null +++ b/native/utils/tokenfree/byte_encoder.cc @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tokenfree/byte_encoder.h" + +#include <vector> +namespace libtextclassifier3 { + +bool ByteEncoder::Encode(StringPiece input_text, + std::vector<int64_t>* encoded_text) const { + const int len = input_text.size(); + if (len <= 0) { + *encoded_text = {}; + return true; + } + + int size = input_text.size(); + encoded_text->resize(size); + + const auto& text = input_text.ToString(); + for (int i = 0; i < size; i++) { + int64_t encoding = static_cast<int64_t>(text[i]); + (*encoded_text)[i] = encoding; + } + + return true; +} + +} // namespace libtextclassifier3 diff --git a/native/utils/tokenfree/byte_encoder.h b/native/utils/tokenfree/byte_encoder.h new file mode 100644 index 0000000..1a495ec --- /dev/null +++ b/native/utils/tokenfree/byte_encoder.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_ +#define LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_ + +#include <vector> + +#include "utils/base/logging.h" +#include "utils/container/string-set.h" +#include "utils/strings/stringpiece.h" + +namespace libtextclassifier3 { + +// Encoder to segment/tokenize strings into bytes +class ByteEncoder { + public: + bool Encode(StringPiece input_text, std::vector<int64_t>* encoded_text) const; + ByteEncoder() {} +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_ diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc new file mode 100644 index 0000000..d4d119e --- /dev/null +++ b/native/utils/tokenfree/byte_encoder_test.cc @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tokenfree/byte_encoder.h" + +#include <memory> +#include <vector> + +#include "utils/base/integral_types.h" +#include "utils/container/sorted-strings-table.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +using testing::ElementsAre; + +TEST(EncoderTest, SimpleTokenization) { + const ByteEncoder encoder; + { + std::vector<int64_t> encoded_text; + EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text)); + EXPECT_THAT(encoded_text, + ElementsAre(104, 101, 108, 108, 111, 116, 104, 101, 114, 101)); + } +} + +TEST(EncoderTest, SimpleTokenization2) { + const ByteEncoder encoder; + { + std::vector<int64_t> encoded_text; + EXPECT_TRUE(encoder.Encode("Hello", &encoded_text)); + EXPECT_THAT(encoded_text, ElementsAre(72, 101, 108, 108, 111)); + } +} +} // namespace +} // namespace libtextclassifier3 diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc index 071141c..7038517 100644 --- a/native/utils/tokenizer.cc +++ b/native/utils/tokenizer.cc @@ -43,11 +43,12 @@ Tokenizer::Tokenizer( codepoint_ranges_.emplace_back(range->UnPack()); } - std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(), - [](const std::unique_ptr<const TokenizationCodepointRangeT>& a, - const std::unique_ptr<const TokenizationCodepointRangeT>& b) { - return a->start < b->start; - }); + std::stable_sort( + codepoint_ranges_.begin(), codepoint_ranges_.end(), + [](const std::unique_ptr<const TokenizationCodepointRangeT>& a, + const std::unique_ptr<const TokenizationCodepointRangeT>& b) { + return a->start < b->start; + }); SortCodepointRanges(internal_tokenizer_codepoint_ranges, &internal_tokenizer_codepoint_ranges_); diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java index bc30fcf..f539ba7 100644 --- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java +++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java @@ -37,15 +37,20 @@ import androidx.test.filters.LargeTest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @LargeTest @RunWith(AndroidJUnit4.class) public class SmartSuggestionsLogSessionTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final String RESULT_ID = "resultId"; private static final String REPLY = "reply"; private static final float SCORE = 0.5f; @@ -55,7 +60,6 @@ public class SmartSuggestionsLogSessionTest { @Before public void setup() { - MockitoAnnotations.initMocks(this); session = new SmartSuggestionsLogSession( |