diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-04-28 15:57:22 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-04-28 15:57:22 +0000 |
commit | c67d414280899fdb72658eb513415289a1751c07 (patch) | |
tree | 4b3ec7a534166d5663d826de85b941d82e74ea9b | |
parent | aa4d582837324cd2cb9e32f01fcb2553d16fa1bf (diff) | |
parent | 775e966e07fb11a55afff5ab93b79128c29a84ac (diff) | |
download | libtextclassifier-android13-frc-cellbroadcast-release.tar.gz |
Snap for 8512216 from 775e966e07fb11a55afff5ab93b79128c29a84ac to tm-frc-cellbroadcast-releaset_frc_cbr_330443000android13-frc-cellbroadcast-release
Change-Id: Ib96d5698a907ff3295f80cf84ff6b39c48bc6065
19 files changed, 460 insertions, 433 deletions
@@ -2,6 +2,6 @@ # Please update this list if you find better candidates. tonymak@google.com toki@google.com -zilka@google.com -mns@google.com -jalt@google.com +licha@google.com +joannechung@google.com +lpeter@google.com
\ No newline at end of file diff --git a/TEST_MAPPING b/TEST_MAPPING index 72e022b..370acd6 100644 --- a/TEST_MAPPING +++ b/TEST_MAPPING @@ -21,6 +21,25 @@ "name": "TCSModelDownloaderIntegrationTest" } ], + "hwasan-postsubmit": [ + { + "name": "TextClassifierServiceTest", + "options": [ + { + "exclude-annotation": "androidx.test.filters.FlakyTest" + } + ] + }, + { + "name": "libtextclassifier_tests" + }, + { + "name": "libtextclassifier_java_tests" + }, + { + "name": "TextClassifierNotificationTests" + } + ], "mainline-presubmit": [ { "name": "TextClassifierNotificationTests[com.google.android.extservices.apex]" diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java index fd64581..bde3898 100644 --- a/java/src/com/android/textclassifier/ExtrasUtils.java +++ b/java/src/com/android/textclassifier/ExtrasUtils.java @@ -87,7 +87,9 @@ public final class ExtrasUtils { return classification.getExtras().getBundle(FOREIGN_LANGUAGE); } - /** @see #getTopLanguage(Intent) */ + /** + * @see #getTopLanguage(Intent) + */ static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) { final int maxSize = Math.min(3, languageScores.getEntities().size()); final String[] languages = diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java index 1ae79ce..9bdfb5e 100644 --- a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java +++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java @@ -195,7 +195,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager @Override public void onDownloadCompleted( ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) { - TcLog.v(TAG, "Start to clean up models and update model lookup cache..."); + TcLog.d(TAG, "Start to clean up models and update model lookup cache..."); // Step 1: Clean up ManifestEnrollment table List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments(); List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>(); @@ -286,7 +286,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager // Clear the cache table and rebuild the cache based on ModelView table private void updateCache() { synchronized (cacheLock) { - TcLog.v(TAG, "Updating model lookup cache..."); + TcLog.d(TAG, "Updating model lookup cache..."); for (String modelType : ModelType.values()) { modelLookupCache.get(modelType).clear(); } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java index b125f13..af33e21 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java @@ -44,6 +44,7 @@ import com.android.textclassifier.utils.IndentingPrintWriter; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Enums; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.hash.Hashing; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; @@ -54,6 +55,7 @@ import java.time.Instant; import java.util.List; import java.util.Locale; import java.util.UUID; +import java.util.concurrent.Callable; import javax.annotation.Nullable; /** Manager to listen to config update and download latest models. */ @@ -64,6 +66,7 @@ public final class ModelDownloadManager { private final Context appContext; private final Class<? extends ListenableWorker> modelDownloadWorkerClass; + private final Callable<WorkManager> workManagerSupplier; private final DownloadedModelManager downloadedModelManager; private final TextClassifierSettings settings; private final ListeningExecutorService executorService; @@ -84,6 +87,7 @@ public final class ModelDownloadManager { this( appContext, ModelDownloadWorker.class, + () -> WorkManager.getInstance(appContext), DownloadedModelManagerImpl.getInstance(appContext), settings, executorService); @@ -93,11 +97,13 @@ public final class ModelDownloadManager { public ModelDownloadManager( Context appContext, Class<? extends ListenableWorker> modelDownloadWorkerClass, + Callable<WorkManager> workManagerSupplier, DownloadedModelManager downloadedModelManager, TextClassifierSettings settings, ListeningExecutorService executorService) { this.appContext = Preconditions.checkNotNull(appContext); this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass); + this.workManagerSupplier = Preconditions.checkNotNull(workManagerSupplier); this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager); this.settings = Preconditions.checkNotNull(settings); this.executorService = Preconditions.checkNotNull(executorService); @@ -121,22 +127,31 @@ public final class ModelDownloadManager { /** Returns the downlaoded models for the given modelType. */ @Nullable public List<File> listDownloadedModels(@ModelTypeDef String modelType) { - return downloadedModelManager.listModels(modelType); + try { + return downloadedModelManager.listModels(modelType); + } catch (Throwable t) { + TcLog.e(TAG, "Failed to list downloaded models", t); + return ImmutableList.of(); + } } /** Notifies the model downlaoder that the text classifier service is created. */ public void onTextClassifierServiceCreated() { - DeviceConfig.addOnPropertiesChangedListener( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener); - appContext.registerReceiver( - localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED)); - TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered."); - if (!settings.isModelDownloadManagerEnabled()) { - return; + try { + DeviceConfig.addOnPropertiesChangedListener( + DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener); + appContext.registerReceiver( + localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED)); + TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered."); + if (!settings.isModelDownloadManagerEnabled()) { + return; + } + maybeOverrideLocaleListForTesting(); + TcLog.d(TAG, "Try to schedule model download work because TextClassifierService started."); + scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED); + } catch (Throwable t) { + TcLog.e(TAG, "Failed inside onTextClassifierServiceCreated", t); } - maybeOverrideLocaleListForTesting(); - TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started."); - scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED); } // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing. @@ -146,8 +161,12 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - TcLog.v(TAG, "Try to schedule model download work because of system locale changes."); - scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED); + TcLog.d(TAG, "Try to schedule model download work because of system locale changes."); + try { + scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED); + } catch (Throwable t) { + TcLog.e(TAG, "Failed inside onLocaleChanged", t); + } } // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing. @@ -157,16 +176,24 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - maybeOverrideLocaleListForTesting(); - TcLog.v(TAG, "Try to schedule model download work because of device config changes."); - scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED); + TcLog.d(TAG, "Try to schedule model download work because of device config changes."); + try { + maybeOverrideLocaleListForTesting(); + scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED); + } catch (Throwable t) { + TcLog.e(TAG, "Failed inside onTextClassifierDeviceConfigChanged", t); + } } /** Clean up internal states on destroying. */ public void destroy() { - DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener); - appContext.unregisterReceiver(localeChangedReceiver); - TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager"); + try { + DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener); + appContext.unregisterReceiver(localeChangedReceiver); + TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager"); + } catch (Throwable t) { + TcLog.e(TAG, "Failed to destroy ModelDownloadManager", t); + } } /** @@ -178,10 +205,14 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - printWriter.println("ModelDownloadManager:"); - printWriter.increaseIndent(); - downloadedModelManager.dump(printWriter); - printWriter.decreaseIndent(); + try { + printWriter.println("ModelDownloadManager:"); + printWriter.increaseIndent(); + downloadedModelManager.dump(printWriter); + printWriter.decreaseIndent(); + } catch (Throwable t) { + TcLog.e(TAG, "Failed to dump ModelDownloadManager", t); + } } /** @@ -193,54 +224,62 @@ public final class ModelDownloadManager { private void scheduleDownloadWork(int reasonToSchedule) { long workId = Hashing.farmHashFingerprint64().hashUnencodedChars(UUID.randomUUID().toString()).asLong(); - NetworkType networkType = - Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType()) - .or(NetworkType.UNMETERED); - OneTimeWorkRequest downloadRequest = - new OneTimeWorkRequest.Builder(modelDownloadWorkerClass) - .setConstraints( - new Constraints.Builder() - .setRequiredNetworkType(networkType) - .setRequiresBatteryNotLow(true) - .setRequiresStorageNotLow(true) - .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle()) - .setRequiresCharging(settings.getManifestDownloadRequiresCharging()) - .build()) - .setBackoffCriteria( - BackoffPolicy.EXPONENTIAL, - settings.getModelDownloadBackoffDelayInMillis(), - MILLISECONDS) - .setInputData( - new Data.Builder() - .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId) - .putLong( - ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, - Instant.now().toEpochMilli()) - .build()) - .build(); - ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture = - WorkManager.getInstance(appContext) - .enqueueUniqueWork( - UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest) - .getResult(); - Futures.addCallback( - enqueueResultFuture, - new FutureCallback<Operation.State.SUCCESS>() { - @Override - public void onSuccess(Operation.State.SUCCESS unused) { - TcLog.v(TAG, "Download work scheduled."); - TextClassifierDownloadLogger.downloadWorkScheduled( - workId, reasonToSchedule, /* failedToSchedule= */ false); - } + try { + NetworkType networkType = + Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType()) + .or(NetworkType.UNMETERED); + OneTimeWorkRequest downloadRequest = + new OneTimeWorkRequest.Builder(modelDownloadWorkerClass) + .setConstraints( + new Constraints.Builder() + .setRequiredNetworkType(networkType) + .setRequiresBatteryNotLow(true) + .setRequiresStorageNotLow(true) + .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle()) + .setRequiresCharging(settings.getManifestDownloadRequiresCharging()) + .build()) + .setBackoffCriteria( + BackoffPolicy.EXPONENTIAL, + settings.getModelDownloadBackoffDelayInMillis(), + MILLISECONDS) + .setInputData( + new Data.Builder() + .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId) + .putLong( + ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, + Instant.now().toEpochMilli()) + .build()) + .build(); + ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture = + workManagerSupplier + .call() + .enqueueUniqueWork( + UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest) + .getResult(); + Futures.addCallback( + enqueueResultFuture, + new FutureCallback<Operation.State.SUCCESS>() { + @Override + public void onSuccess(Operation.State.SUCCESS unused) { + TcLog.d(TAG, "Download work scheduled."); + TextClassifierDownloadLogger.downloadWorkScheduled( + workId, reasonToSchedule, /* failedToSchedule= */ false); + } - @Override - public void onFailure(Throwable t) { - TcLog.e(TAG, "Failed to schedule download work: ", t); - TextClassifierDownloadLogger.downloadWorkScheduled( - workId, reasonToSchedule, /* failedToSchedule= */ true); - } - }, - executorService); + @Override + public void onFailure(Throwable t) { + TcLog.e(TAG, "Failed to schedule download work: ", t); + TextClassifierDownloadLogger.downloadWorkScheduled( + workId, reasonToSchedule, /* failedToSchedule= */ true); + } + }, + executorService); + } catch (Throwable t) { + // TODO(licha): this is just for temporary fix. Refactor the try-catch in the future. + TcLog.e(TAG, "Failed to schedule download work: ", t); + TextClassifierDownloadLogger.downloadWorkScheduled( + workId, reasonToSchedule, /* failedToSchedule= */ true); + } } private void maybeOverrideLocaleListForTesting() { diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java index 6e04e16..3db0815 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java @@ -113,6 +113,7 @@ public final class ModelDownloadWorker extends ListenableWorker { @Override public final ListenableFuture<ListenableWorker.Result> startWork() { + TcLog.d(TAG, "Start download work..."); workStartedTimeMillis = getCurrentTimeMillis(); // Notice: startWork() is invoked on the main thread if (!settings.isModelDownloadManagerEnabled()) { @@ -121,7 +122,6 @@ public final class ModelDownloadWorker extends ListenableWorker { TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED); return Futures.immediateFuture(ListenableWorker.Result.failure()); } - TcLog.v(TAG, "Start download work..."); if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) { TcLog.d(TAG, "Max attempt reached. Abort download work."); logDownloadWorkCompleted( @@ -134,7 +134,7 @@ public final class ModelDownloadWorker extends ListenableWorker { downloadResult -> { Preconditions.checkNotNull(manifestsToDownload); downloadedModelManager.onDownloadCompleted(manifestsToDownload); - TcLog.v(TAG, "Download work completed: " + downloadResult); + TcLog.d(TAG, "Download work completed: " + downloadResult); if (downloadResult.failureCount() == 0) { logDownloadWorkCompleted( downloadResult.successCount() > 0 @@ -239,7 +239,7 @@ public final class ModelDownloadWorker extends ListenableWorker { return Futures.whenAllComplete(downloadResultFutures) .call( () -> { - TcLog.v(TAG, "All Download Tasks Completed"); + TcLog.d(TAG, "All Download Tasks Completed"); int successCount = 0; int failureCount = 0; for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) { @@ -333,7 +333,7 @@ public final class ModelDownloadWorker extends ListenableWorker { Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl); if (downloadedManifest != null && downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) { - TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl); + TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl); return Futures.immediateVoidFuture(); } if (pendingDownloads.containsKey(manifestUrl)) { diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java index 2244e9a..0b76f22 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java @@ -99,7 +99,7 @@ final class ModelDownloaderImpl implements ModelDownloader { new FutureCallback<File>() { @Override public void onSuccess(File pendingModelFile) { - TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); + TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); } @Override @@ -170,11 +170,11 @@ final class ModelDownloaderImpl implements ModelDownloader { } catch (IOException e) { throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e); } - TcLog.v(TAG, "Pending model file passed validation."); + TcLog.d(TAG, "Pending model file passed validation."); } private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) { - TcLog.v(TAG, "Starting a new connection to ModelDownloaderService"); + TcLog.d(TAG, "Starting a new connection to ModelDownloaderService"); return CallbackToFutureAdapter.getFuture( completer -> { conn.attachCompleter(completer); @@ -197,7 +197,7 @@ final class ModelDownloaderImpl implements ModelDownloader { // restult future will hang there until time out. (WorkManager forces a 10-min running time.) private static ListenableFuture<Long> scheduleDownload( IModelDownloaderService service, URI uri, File targetFile) { - TcLog.v(TAG, "Scheduling a new download task with ModelDownloaderService"); + TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService"); return CallbackToFutureAdapter.getFuture( completer -> { service.download( @@ -236,7 +236,7 @@ final class ModelDownloaderImpl implements ModelDownloader { @Override public void onServiceConnected(ComponentName componentName, IBinder iBinder) { - TcLog.v(TAG, "DownloaderService connected"); + TcLog.d(TAG, "DownloaderService connected"); completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder))); } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java index e4ebbfa..6d7e47e 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java @@ -39,7 +39,7 @@ public final class ModelDownloaderService extends Service { @Override public IBinder onBind(Intent intent) { - TcLog.v(TAG, "Binding to ModelDownloadService"); + TcLog.d(TAG, "Binding to ModelDownloadService"); return iBinder; } } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java index 439588b..47e6f19 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java @@ -91,7 +91,7 @@ final class ModelDownloaderServiceImpl extends IModelDownloaderService.Stub { @Override public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) { - TcLog.v(TAG, "Download request received: " + uri); + TcLog.d(TAG, "Download request received: " + uri); try { File targetFile = new File(targetFilePath); File tempMetadataFile = getMetadataFile(targetFile); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java index 71f9a4f..ddab8bd 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java @@ -17,7 +17,10 @@ package com.android.textclassifier; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import android.content.Context; import android.os.CancellationSignal; @@ -39,6 +42,7 @@ import com.android.os.AtomsProto.Atom; import com.android.os.AtomsProto.TextClassifierApiUsageReported; import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType; import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType; +import com.android.textclassifier.common.ModelType; import com.android.textclassifier.common.TextClassifierSettings; import com.android.textclassifier.common.statsd.StatsdTestUtils; import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger; @@ -47,6 +51,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import java.io.IOException; import java.util.List; import java.util.concurrent.Executor; import java.util.stream.Collectors; @@ -81,13 +86,21 @@ public class DefaultTextClassifierServiceTest { @Mock private TextClassifierService.Callback<TextLinks> textLinksCallback; @Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback; @Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback; + @Mock private ModelFileManager testModelFileManager; @Before - public void setup() { - - testInjector = new TestInjector(ApplicationProvider.getApplicationContext()); + public void setup() throws IOException { + testInjector = + new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager); defaultTextClassifierService = new DefaultTextClassifierService(testInjector); defaultTextClassifierService.onCreate(); + + when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); + when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) + .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); + when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) + .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); } @Before @@ -211,11 +224,8 @@ public class DefaultTextClassifierServiceTest { @Test public void missingModelFile_onFailureShouldBeCalled() throws Exception { - testInjector.setModelFileManager( - new ModelFileManagerImpl( - ApplicationProvider.getApplicationContext(), - ImmutableList.of(), - testInjector.createTextClassifierSettings())); + when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(null); defaultTextClassifierService.onCreate(); TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build(); @@ -251,12 +261,9 @@ public class DefaultTextClassifierServiceTest { private final Context context; private ModelFileManager modelFileManager; - private TestInjector(Context context) { + private TestInjector(Context context, ModelFileManager modelFileManager) { this.context = Preconditions.checkNotNull(context); - } - - private void setModelFileManager(ModelFileManager modelFileManager) { - this.modelFileManager = modelFileManager; + this.modelFileManager = Preconditions.checkNotNull(modelFileManager); } @Override @@ -267,9 +274,6 @@ public class DefaultTextClassifierServiceTest { @Override public ModelFileManager createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) { - if (modelFileManager == null) { - return TestDataUtils.createModelFileManagerForTesting(context); - } return modelFileManager; } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java index 20ae592..0e40515 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java @@ -25,6 +25,7 @@ import android.os.LocaleList; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.filters.SmallTest; +import androidx.work.WorkManager; import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister; import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister; import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister; @@ -87,6 +88,7 @@ public final class ModelFileManagerImplTest { new ModelDownloadManager( context, ModelDownloadWorker.class, + () -> WorkManager.getInstance(context), downloadedModelManager, settings, MoreExecutors.newDirectExecutorService()); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java index bac4fa1..a19e3ff 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java @@ -16,12 +16,10 @@ package com.android.textclassifier; -import android.content.Context; -import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister; +import com.android.textclassifier.common.ModelFile; import com.android.textclassifier.common.ModelType; -import com.android.textclassifier.common.TextClassifierSettings; -import com.google.common.collect.ImmutableList; import java.io.File; +import java.io.IOException; /** Utils to access test data files. */ public final class TestDataUtils { @@ -30,7 +28,7 @@ public final class TestDataUtils { private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model"; /** Returns the root folder that contains the test data. */ - public static File getTestDataFolder() { + private static File getTestDataFolder() { return new File("/data/local/tmp/TextClassifierServiceTest/"); } @@ -38,24 +36,25 @@ public final class TestDataUtils { return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH); } + public static ModelFile getTestAnnotatorModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile(getTestAnnotatorModelFile(), ModelType.ANNOTATOR); + } + public static File getTestActionsModelFile() { return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH); } + public static ModelFile getTestActionsModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile( + getTestActionsModelFile(), ModelType.ACTIONS_SUGGESTIONS); + } + public static File getLangIdModelFile() { return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH); } - public static ModelFileManager createModelFileManagerForTesting(Context context) { - return new ModelFileManagerImpl( - context, - ImmutableList.of( - new RegularFileFullMatchLister( - ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true), - new RegularFileFullMatchLister( - ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true), - new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)), - new TextClassifierSettings()); + public static ModelFile getLangIdModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile(getLangIdModelFile(), ModelType.LANG_ID); } private TestDataUtils() {} diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java index 42177e6..e7bf90c 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java @@ -56,6 +56,10 @@ public class TextClassifierApiTest { @Before public void setup() { + extServicesTextClassifierRule.enableVerboseLogging(); + // Verbose logging only takes effect after restarting ExtServices + extServicesTextClassifierRule.forceStopExtServices(); + textClassifier = extServicesTextClassifierRule.getTextClassifier(); } @@ -81,8 +85,8 @@ public class TextClassifierApiTest { @Test public void classifyText() { - String text = "Contact me at droid@android.com"; - String classifiedText = "droid@android.com"; + String text = "Contact me at http://www.android.com"; + String classifiedText = "http://www.android.com"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = @@ -90,7 +94,7 @@ public class TextClassifierApiTest { TextClassification classification = textClassifier.classifyText(request); assertThat(classification.getEntityCount()).isGreaterThan(0); - assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL); + assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL); assertThat(classification.getText()).isEqualTo(classifiedText); assertThat(classification.getActions()).isNotEmpty(); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java index fb1aea8..c20ec8a 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java @@ -22,6 +22,9 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.testng.Assert.expectThrows; @@ -74,30 +77,34 @@ public class TextClassifierImplTest { private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US"); private static final String NO_TYPE = null; - @Mock private ModelFileManagerImpl.ModelFileLister mockModelFileLister; + @Mock private ModelFileManager modelFileManager; - private TextClassifierSettings settings; private Context context; private TestingDeviceConfig deviceConfig; - private TextClassifierImpl classifier; - - private final ModelFileManager modelFileManager = - TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext()); + private TextClassifierSettings settings; private LruCache<ModelFile, AnnotatorModel> annotatorModelCache; + private TextClassifierImpl classifier; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.initMocks(this); - deviceConfig = new TestingDeviceConfig(); - Context context = + this.context = new FakeContextBuilder() .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT) .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app") .build(); - this.context = context; - settings = new TextClassifierSettings(deviceConfig); - // TODO(veronikanikina): consider using a testing constructor here. - classifier = new TextClassifierImpl(context, settings, modelFileManager); + this.deviceConfig = new TestingDeviceConfig(); + this.settings = new TextClassifierSettings(deviceConfig); + this.annotatorModelCache = new LruCache<>(2); + this.classifier = + new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache); + + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); + when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) + .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); + when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) + .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); } @Test @@ -110,9 +117,7 @@ public class TextClassifierImplTest { int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat( @@ -120,6 +125,24 @@ public class TextClassifierImplTest { } @Test + public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException { + String text = "Contact me at droid@android.com"; + String selected = "droid"; + String suggested = "droid@android.com"; + int startIndex = text.indexOf(selected); + int endIndex = startIndex + selected.length(); + int smartStartIndex = text.indexOf(suggested); + int smartEndIndex = smartStartIndex + suggested.length(); + TextSelection.Request request = + new TextSelection.Request.Builder(text, startIndex, endIndex) + .setDefaultLocales(LOCALES) + .build(); + + classifier.suggestSelection(null, null, request); + verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any()); + } + + @Test public void testSuggestSelection_url() throws IOException { String text = "Visit http://www.android.com for more information"; String selected = "http"; @@ -129,9 +152,7 @@ public class TextClassifierImplTest { int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); @@ -144,9 +165,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(selected); int endIndex = startIndex + selected.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); @@ -160,7 +179,6 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(suggested); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1) - .setDefaultLocales(LOCALES) .setIncludeTextClassification(true) .build(); @@ -178,7 +196,6 @@ public class TextClassifierImplTest { String text = "Visit http://www.android.com for more information"; TextSelection.Request request = new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4) - .setDefaultLocales(LOCALES) .setIncludeTextClassification(false) .build(); @@ -194,9 +211,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(/* sessionId= */ null, null, request); @@ -210,9 +225,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); @@ -223,9 +236,7 @@ public class TextClassifierImplTest { public void testClassifyText_address() throws IOException { String text = "Brandschenkestrasse 110, Zürich, Switzerland"; TextClassification.Request request = - new TextClassification.Request.Builder(text, 0, text.length()) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, 0, text.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); @@ -238,9 +249,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); @@ -254,9 +263,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); @@ -275,9 +282,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); @@ -289,14 +294,12 @@ public class TextClassifierImplTest { LocaleList.setDefault(LocaleList.forLanguageTags("en")); String japaneseText = "これは日本語のテキストです"; TextClassification.Request request = - new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); RemoteAction translateAction = classification.getActions().get(0); assertEquals(1, classification.getActions().size()); - assertEquals("Translate", translateAction.getTitle().toString()); + assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction()); assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification)); Intent intent = ExtrasUtils.getActionsIntents(classification).get(0); @@ -323,18 +326,17 @@ public class TextClassifierImplTest { @Test public void testGenerateLinks_exclude() throws IOException { - String text = "You want apple@banana.com. See you tonight!"; + String text = "The number is +12122537077. See you tonight!"; List<String> hints = ImmutableList.of(); List<String> included = ImmutableList.of(); - List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL); + List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE); TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), - not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); + not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE))); } @Test @@ -344,7 +346,6 @@ public class TextClassifierImplTest { TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), @@ -361,7 +362,6 @@ public class TextClassifierImplTest { TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), @@ -573,29 +573,16 @@ public class TextClassifierImplTest { new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); - String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath(); - ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false); - - annotatorModelCache = new LruCache<>(2); - ModelFileManager modelFileManagerCached = - new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings); - TextClassifierImpl textClassifierImpl = - new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache); - LocaleList.setDefault(LocaleList.forLanguageTags("en")); String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = - new TextClassification.Request.Builder(englishText, 0, englishText.length()) - .setDefaultLocales(LOCALES) - .build(); - - when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel)); + new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationA = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationA = classifier.classifyText(null, null, request); assertThat(classificationA.getId()).contains("v701"); assertThat(classificationA.getText()).contains(classifiedText); @@ -609,9 +596,9 @@ public class TextClassifierImplTest { }); // Check modelFileB v801 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelB)); - TextClassification classificationB = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelB); + TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); @@ -625,9 +612,9 @@ public class TextClassifierImplTest { }); // Reload modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); @@ -651,28 +638,16 @@ public class TextClassifierImplTest { new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); - String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath(); - ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false); - - annotatorModelCache = new LruCache<>(settings.getMultiAnnotatorCacheSize()); - ModelFileManager modelFileManagerCached = - new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings); - TextClassifierImpl textClassifierImpl = - new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache); - LocaleList.setDefault(LocaleList.forLanguageTags("en")); + String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = - new TextClassification.Request.Builder(englishText, 0, englishText.length()) - .setDefaultLocales(LOCALES) - .build(); - - when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel)); + new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classification = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification.getId()).contains("v701"); assertThat(classification.getText()).contains(classifiedText); @@ -686,9 +661,9 @@ public class TextClassifierImplTest { }); // Check modelFileB v801 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelB)); - TextClassification classificationB = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelB); + TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); @@ -702,9 +677,9 @@ public class TextClassifierImplTest { }); // Reload modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java index 394b7ad..9e11c09 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java @@ -67,6 +67,7 @@ public final class ModelDownloadManagerTest { private TestingDeviceConfig deviceConfig; private WorkManager workManager; private ModelDownloadManager downloadManager; + private ModelDownloadManager downloadManagerWithBadWorkManager; @Mock DownloadedModelManager downloadedModelManager; @Before @@ -80,6 +81,17 @@ public final class ModelDownloadManagerTest { new ModelDownloadManager( context, ModelDownloadWorker.class, + () -> workManager, + downloadedModelManager, + new TextClassifierSettings(deviceConfig), + MoreExecutors.newDirectExecutorService()); + this.downloadManagerWithBadWorkManager = + new ModelDownloadManager( + context, + ModelDownloadWorker.class, + () -> { + throw new IllegalStateException("WorkManager may fail!"); + }, downloadedModelManager, new TextClassifierSettings(deviceConfig), MoreExecutors.newDirectExecutorService()); @@ -96,7 +108,20 @@ public final class ModelDownloadManagerTest { } @Test + public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); + downloadManagerWithBadWorkManager.onTextClassifierServiceCreated(); + + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test + TextClassifierDownloadWorkScheduled atom = + Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); + assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED); + assertThat(atom.getFailedToSchedule()).isTrue(); + } + + @Test public void onTextClassifierServiceCreated_requestEnqueued() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); downloadManager.onTextClassifierServiceCreated(); WorkInfo workInfo = @@ -104,21 +129,34 @@ public final class ModelDownloadManagerTest { DownloaderTestUtils.queryWorkInfos( workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } @Test public void onTextClassifierServiceCreated_localeListOverridden() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr"); downloadManager.onTextClassifierServiceCreated(); assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh")); assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } @Test + public void onLocaleChanged_workManagerCrashed() throws Exception { + downloadManagerWithBadWorkManager.onLocaleChanged(); + + TextClassifierDownloadWorkScheduled atom = + Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); + assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED); + assertThat(atom.getFailedToSchedule()).isTrue(); + } + + @Test public void onLocaleChanged_requestEnqueued() throws Exception { downloadManager.onLocaleChanged(); @@ -131,6 +169,16 @@ public final class ModelDownloadManagerTest { } @Test + public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception { + downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged(); + + TextClassifierDownloadWorkScheduled atom = + Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); + assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED); + assertThat(atom.getFailedToSchedule()).isTrue(); + } + + @Test public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception { downloadManager.onTextClassifierDeviceConfigChanged(); @@ -188,6 +236,13 @@ public final class ModelDownloadManagerTest { assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile); } + @Test + public void listDownloadedModels_doNotCrashOnError() throws Exception { + when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException()); + + assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty(); + } + private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception { TextClassifierDownloadWorkScheduled atom = Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java index e4360c6..e261158 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java @@ -18,16 +18,10 @@ package com.android.textclassifier.downloader; import static com.google.common.truth.Truth.assertThat; -import android.app.Instrumentation; -import android.app.UiAutomation; import android.util.Log; import android.view.textclassifier.TextClassification; import android.view.textclassifier.TextClassification.Request; -import android.view.textclassifier.TextClassifier; -import androidx.test.platform.app.InstrumentationRegistry; import com.android.textclassifier.testing.ExtServicesTextClassifierRule; -import com.android.textclassifier.testing.TestingLocaleListOverrideRule; -import java.io.IOException; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -48,171 +42,133 @@ public class ModelDownloaderIntegrationTest { private static final String V804_EN_TAG = "en_v804"; private static final String V804_RU_TAG = "ru_v804"; private static final String FACTORY_MODEL_TAG = "*"; - - @Rule - public final TestingLocaleListOverrideRule testingLocaleListOverrideRule = - new TestingLocaleListOverrideRule(); + private static final int ASSERT_MAX_ATTEMPTS = 20; + private static final int ASSERT_SLEEP_BEFORE_RETRY_MS = 1000; @Rule public final ExtServicesTextClassifierRule extServicesTextClassifierRule = new ExtServicesTextClassifierRule(); - private TextClassifier textClassifier; - @Before public void setup() throws Exception { - // Flag overrides below can be overridden by Phenotype sync, which makes this test flaky - runShellCommand("device_config put textclassifier config_updater_model_enabled false"); - runShellCommand("device_config put textclassifier model_download_manager_enabled true"); - runShellCommand("device_config put textclassifier model_download_backoff_delay_in_millis 5"); - - textClassifier = extServicesTextClassifierRule.getTextClassifier(); - startExtservicesProcess(); + extServicesTextClassifierRule.addDeviceConfigOverride("config_updater_model_enabled", "false"); + extServicesTextClassifierRule.addDeviceConfigOverride("model_download_manager_enabled", "true"); + extServicesTextClassifierRule.addDeviceConfigOverride( + "model_download_backoff_delay_in_millis", "5"); + extServicesTextClassifierRule.addDeviceConfigOverride("testing_locale_list_override", "en-US"); + extServicesTextClassifierRule.overrideDeviceConfig(); + + extServicesTextClassifierRule.enableVerboseLogging(); + // Verbose logging only takes effect after restarting ExtServices + extServicesTextClassifierRule.forceStopExtServices(); } @After public void tearDown() throws Exception { - runShellCommand("device_config delete textclassifier manifest_url_annotator_en"); - runShellCommand("device_config delete textclassifier manifest_url_annotator_ru"); - runShellCommand("device_config put textclassifier config_updater_model_enabled true"); - runShellCommand("device_config delete textclassifier multi_language_support_enabled"); - runShellCommand( - "device_config put textclassifier model_download_backoff_delay_in_millis 3600000"); + // This is to reset logging/locale_override for ExtServices. + extServicesTextClassifierRule.forceStopExtServices(); } @Test - public void smokeTest() throws IOException, InterruptedException { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + public void smokeTest() throws Exception { + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG)); + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); } @Test - public void downgradeModel() throws IOException, InterruptedException { + public void downgradeModel() throws Exception { // Download an experimental model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 1000, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); // Downgrade to an older model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG)); - } + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); } @Test - public void upgradeModel() throws IOException, InterruptedException { + public void upgradeModel() throws Exception { // Download a model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG)); - } + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); // Upgrade to an experimental model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 1000, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); } @Test - public void clearFlag() throws IOException, InterruptedException { + public void clearFlag() throws Exception { // Download a new model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 1000, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); // Revert the flag. - { - runShellCommand("device_config delete textclassifier manifest_url_annotator_en"); - // Fallback to use the universal model. - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 1000, - () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride("manifest_url_annotator_en", ""); + // Fallback to use the universal model. + assertWithRetries( + () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG)); } @Test - public void modelsForMultipleLanguagesDownloaded() throws IOException, InterruptedException { - runShellCommand("device_config put textclassifier multi_language_support_enabled true"); - testingLocaleListOverrideRule.set("en-US", "ru-RU"); + public void modelsForMultipleLanguagesDownloaded() throws Exception { + extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true"); + extServicesTextClassifierRule.addDeviceConfigOverride( + "testing_locale_list_override", "en-US,ru-RU"); // download en model - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); // download ru model - runShellCommand( - "device_config put textclassifier manifest_url_annotator_ru " - + V804_RU_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 1000, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_ru", V804_RU_ANNOTATOR_MANIFEST_URL); + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - assertWithRetries(/* maxAttempts= */ 10, /* sleepMs= */ 1000, this::verifyActiveRussianModel); + assertWithRetries(this::verifyActiveRussianModel); assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 1000, () -> verifyActiveModel(/* text= */ "français", /* expectedVersion= */ FACTORY_MODEL_TAG)); } - private void assertWithRetries(int maxAttempts, int sleepMs, Runnable assertRunnable) - throws InterruptedException { - for (int i = 0; i < maxAttempts; i++) { + private void assertWithRetries(Runnable assertRunnable) throws Exception { + for (int i = 0; i < ASSERT_MAX_ATTEMPTS; i++) { try { + extServicesTextClassifierRule.overrideDeviceConfig(); assertRunnable.run(); break; // success. Bail out. } catch (AssertionError ex) { - if (i == maxAttempts - 1) { // last attempt, give up. + if (i == ASSERT_MAX_ATTEMPTS - 1) { // last attempt, give up. + extServicesTextClassifierRule.dumpDefaultTextClassifierService(); throw ex; } else { - Thread.sleep(sleepMs); + Thread.sleep(ASSERT_SLEEP_BEFORE_RETRY_MS); } + } catch (Exception unknownException) { + throw unknownException; } } } private void verifyActiveModel(String text, String expectedVersion) { TextClassification textClassification = - textClassifier.classifyText(new Request.Builder(text, 0, text.length()).build()); - Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId()); + extServicesTextClassifierRule + .getTextClassifier() + .classifyText(new Request.Builder(text, 0, text.length()).build()); // The result id contains the name of the just used model. + Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId()); assertThat(textClassification.getId()).contains(expectedVersion); } @@ -223,16 +179,4 @@ public class ModelDownloaderIntegrationTest { private void verifyActiveRussianModel() { verifyActiveModel("привет", V804_RU_TAG); } - - private void startExtservicesProcess() { - // Start the process of ExtServices by sending it a text classifier request. - textClassifier.classifyText(new TextClassification.Request.Builder("abc", 0, 3).build()); - } - - private static void runShellCommand(String cmd) { - Log.v(TAG, "run shell command: " + cmd); - Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation(); - UiAutomation uiAutomation = instrumentation.getUiAutomation(); - uiAutomation.executeShellCommand(cmd); - } } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java index 3ceb47b..5f8247d 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java @@ -20,64 +20,72 @@ import android.app.UiAutomation; import android.content.pm.PackageManager; import android.content.pm.PackageManager.NameNotFoundException; import android.provider.DeviceConfig; +import android.util.Log; import android.view.textclassifier.TextClassificationManager; import android.view.textclassifier.TextClassifier; import androidx.test.core.app.ApplicationProvider; import androidx.test.platform.app.InstrumentationRegistry; +import com.google.common.io.ByteStreams; +import java.io.FileInputStream; +import java.io.IOException; import org.junit.rules.ExternalResource; /** A rule that manages a text classifier that is backed by the ExtServices. */ public final class ExtServicesTextClassifierRule extends ExternalResource { + private static final String TAG = "androidtc"; private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE = "textclassifier_service_package_override"; private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services"; private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services"; - private String textClassifierServiceOverrideFlagOldValue; + private UiAutomation uiAutomation; + private DeviceConfig.Properties originalProperties; + private DeviceConfig.Properties.Builder newPropertiesBuilder; @Override - protected void before() { - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); - try { - uiAutomation.adoptShellPermissionIdentity(); - textClassifierServiceOverrideFlagOldValue = - DeviceConfig.getString( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - null); - DeviceConfig.setProperty( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - getExtServicesPackageName(), - /* makeDefault= */ false); - } finally { - uiAutomation.dropShellPermissionIdentity(); - } + protected void before() throws Exception { + uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); + uiAutomation.adoptShellPermissionIdentity(); + originalProperties = DeviceConfig.getProperties(DeviceConfig.NAMESPACE_TEXTCLASSIFIER); + newPropertiesBuilder = + new DeviceConfig.Properties.Builder(DeviceConfig.NAMESPACE_TEXTCLASSIFIER) + .setString( + CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, getExtServicesPackageName()); + overrideDeviceConfig(); } @Override protected void after() { - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); try { - uiAutomation.adoptShellPermissionIdentity(); - DeviceConfig.setProperty( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - textClassifierServiceOverrideFlagOldValue, - /* makeDefault= */ false); + DeviceConfig.setProperties(originalProperties); + } catch (Throwable t) { + Log.e(TAG, "Failed to reset DeviceConfig", t); } finally { uiAutomation.dropShellPermissionIdentity(); } } - private static String getExtServicesPackageName() { - PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager(); - try { - packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0); - return PKG_NAME_GOOGLE_EXTSERVICES; - } catch (NameNotFoundException e) { - return PKG_NAME_AOSP_EXTSERVICES; - } + public void addDeviceConfigOverride(String name, String value) { + newPropertiesBuilder.setString(name, value); + } + + /** + * Overrides the TextClassifier DeviceConfig manually. + * + * <p>This will clean up all device configs not in newPropertiesBuilder. + * + * <p>We will need to call this everytime before testing, because DeviceConfig can be synced in + * background at anytime. DeviceConfig#setSyncDisabledMode is to disable sync, however it's a + * hidden API. + */ + public void overrideDeviceConfig() throws Exception { + DeviceConfig.setProperties(newPropertiesBuilder.build()); + } + + /** Force stop ExtServices. Force-stop-and-start can be helpful to reload some states. */ + public void forceStopExtServices() { + runShellCommand("am force-stop com.google.android.ext.services"); + runShellCommand("am force-stop android.ext.services"); } public TextClassifier getTextClassifier() { @@ -87,4 +95,38 @@ public final class ExtServicesTextClassifierRule extends ExternalResource { textClassificationManager.setTextClassifier(null); // Reset TC overrides return textClassificationManager.getTextClassifier(); } + + public void dumpDefaultTextClassifierService() { + runShellCommand( + "dumpsys activity service com.google.android.ext.services/" + + "com.android.textclassifier.DefaultTextClassifierService"); + runShellCommand("cmd device_config list textclassifier"); + } + + public void enableVerboseLogging() { + runShellCommand("setprop log.tag.androidtc VERBOSE"); + } + + private void runShellCommand(String cmd) { + Log.v(TAG, "run shell command: " + cmd); + try (FileInputStream output = + new FileInputStream(uiAutomation.executeShellCommand(cmd).getFileDescriptor())) { + String cmdOutput = new String(ByteStreams.toByteArray(output)); + if (!cmdOutput.isEmpty()) { + Log.d(TAG, "cmd output: " + cmdOutput); + } + } catch (IOException ioe) { + Log.w(TAG, "failed to get cmd output", ioe); + } + } + + private static String getExtServicesPackageName() { + PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager(); + try { + packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0); + return PKG_NAME_GOOGLE_EXTSERVICES; + } catch (NameNotFoundException e) { + return PKG_NAME_AOSP_EXTSERVICES; + } + } } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java deleted file mode 100644 index 7d46e97..0000000 --- a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (C) 2018 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.textclassifier.testing; - -import android.app.UiAutomation; -import android.os.LocaleList; -import android.util.Log; -import androidx.test.platform.app.InstrumentationRegistry; -import org.junit.rules.ExternalResource; - -/** class for overriding testing_locale_list_override from {@link TextClassifierSettings} */ -public final class TestingLocaleListOverrideRule extends ExternalResource { - private static final String TAG = "TestingLocaleListOverrideRule"; - - private LocaleList originalLocaleList; - - @Override - protected void before() { - originalLocaleList = LocaleList.getDefault(); - } - - public void set(String... localeTags) { - if (localeTags.length == 0) { - return; - } - runShellCommand( - "device_config put textclassifier testing_locale_list_override " - + String.join(",", localeTags)); - } - - @Override - protected void after() { - runShellCommand( - "device_config put textclassifier testing_locale_list_override " - + originalLocaleList.toLanguageTags()); - runShellCommand("device_config delete textclassifier testing_locale_list_override"); - } - - private static void runShellCommand(String cmd) { - Log.v(TAG, "run shell command: " + cmd); - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); - uiAutomation.executeShellCommand(cmd); - } -} diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc index d4d119e..964e316 100644 --- a/native/utils/tokenfree/byte_encoder_test.cc +++ b/native/utils/tokenfree/byte_encoder_test.cc @@ -29,7 +29,7 @@ namespace { using testing::ElementsAre; -TEST(EncoderTest, SimpleTokenization) { +TEST(ByteEncoderTest, SimpleTokenization) { const ByteEncoder encoder; { std::vector<int64_t> encoded_text; @@ -39,7 +39,7 @@ TEST(EncoderTest, SimpleTokenization) { } } -TEST(EncoderTest, SimpleTokenization2) { +TEST(ByteEncoderTest, SimpleTokenization2) { const ByteEncoder encoder; { std::vector<int64_t> encoded_text; |