diff options
author | Presubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com> | 2022-06-02 18:29:08 +0000 |
---|---|---|
committer | Presubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com> | 2022-06-02 18:29:08 +0000 |
commit | bf4efa0424141f9e77edc5f103aa48f2d62d6a36 (patch) | |
tree | 0a5b70a22853376cf2a1f84612cd8d810ac94357 | |
parent | fb6bf1937de42d081d0736baa5380597d262da1d (diff) | |
parent | 3119c9ac2289a461a8c5e40b576dec683a30f0f1 (diff) | |
download | tflite-support-bf4efa0424141f9e77edc5f103aa48f2d62d6a36.tar.gz |
[automerge] Port BertNLClassifierTest into AOSP 2p: 3119c9ac22
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/tflite-support/+/18716813
Bug: 232807230
Change-Id: I68c7792a856814dde3a32e96ea4e10c029b12fb2
5 files changed, 181 insertions, 0 deletions
@@ -373,6 +373,57 @@ cc_test { ], } +android_test { + name: "TfliteSupportClassifierTests", + srcs: [ + "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java", + ], + asset_dirs: [ + "tensorflow_lite_support/java/src/javatests/testdata/task/text", + ], + defaults: ["modules-utils-testable-device-config-defaults"], + manifest: "tensorflow_lite_support/java/AndroidManifest.xml", + sdk_version: "module_current", + min_sdk_version: "30", + static_libs: [ + "androidx.test.core", + "tensorflowlite_java", + "truth-prebuilt", + "tflite_support_classifiers_java", + "tflite_support_test_utils_java", + ], + libs: [ + "android.test.base", + "android.test.mock.stubs", + ], + test_suites: [ + "general-tests", + ], + jni_libs: [ + "libtflite_support_classifiers_native", + ], + aaptflags: [ + // Avoid compression on tflite files as the Interpreter + // can not load compressed flat buffer formats. + // (*appt compresses all assets into the apk by default) + // See https://elinux.org/Android_aapt for more detail. + "-0 .tflite", + ], +} + +java_library_static { + name: "tflite_support_test_utils_java", + sdk_version: "module_current", + min_sdk_version: "30", + srcs: [ + "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java", + ], + static_libs: [ + "apache-commons-compress", + "guava", + ], +} + cc_library_static { name: "tflite_support_task_core_proto", proto: { diff --git a/tensorflow_lite_support/java/AndroidManifest.xml b/tensorflow_lite_support/java/AndroidManifest.xml index 14909296..c36eb383 100644 --- a/tensorflow_lite_support/java/AndroidManifest.xml +++ b/tensorflow_lite_support/java/AndroidManifest.xml @@ -2,4 +2,8 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.lite.support"> <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> + <instrumentation + android:name="androidx.test.runner.AndroidJUnitRunner" + android:targetPackage="org.tensorflow.lite.support" > + </instrumentation> </manifest> diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java new file mode 100644 index 00000000..4c6f369d --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java @@ -0,0 +1,50 @@ +package org.tensorflow.lite.task.core; + +import android.content.Context; +import android.content.res.AssetManager; + +import com.google.common.io.ByteStreams; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** Helper class for the Java test in Task Libary. */ +public final class TestUtils { + + /** + * Loads the file and create a {@link File} object by reading a file from the asset directory. + * Simulates downloading or reading a file that's not precompiled with the app. + * + * @return a {@link File} object for the model. + */ + public static File loadFile(Context context, String fileName) { + File target = new File(context.getFilesDir(), fileName); + try (InputStream is = context.getAssets().open(fileName); + FileOutputStream os = new FileOutputStream(target)) { + ByteStreams.copy(is, os); + } catch (IOException e) { + throw new AssertionError("Failed to load model file at " + fileName, e); + } + return target; + } + + /** + * Reads a file into a direct {@link ByteBuffer} object from the asset directory. + * + * @return a {@link ByteBuffer} object for the file. + */ + public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName) + throws IOException { + AssetManager assetManager = context.getAssets(); + InputStream inputStream = assetManager.open(fileName); + byte[] bytes = ByteStreams.toByteArray(inputStream); + + ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder()); + buffer.put(bytes); + return buffer; + } +}
\ No newline at end of file diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java new file mode 100644 index 00000000..24a91ac4 --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java @@ -0,0 +1,76 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.text.nlclassifier; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.core.app.ApplicationProvider; +import java.io.IOException; +import java.util.List; +import org.junit.Test; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.TestUtils; + +/** Test for {@link BertNLClassifier}. */ +public class BertNLClassifierTest { + private static final String MODEL_FILE = "bert_nl_classifier.tflite"; + + Category findCategoryWithLabel(List<Category> list, String label) { + return list.stream() + .filter(category -> label.equals(category.getLabel())) + .findAny() + .orElse(null); + } + + @Test + public void createFromPath_verifyResults() throws IOException { + verifyResults( + BertNLClassifier.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)); + } + + @Test + public void createFromFile_verifyResults() throws IOException { + verifyResults( + BertNLClassifier.createFromFile( + TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE))); + } + + @Test + public void classify_succeedsWithModelFile() throws IOException { + verifyResults( + BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), MODEL_FILE)); + } + + @Test + public void classify_succeedsWithModelBuffer() throws IOException { + verifyResults( + BertNLClassifier.createFromBuffer( + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE))); + } + + private void verifyResults(BertNLClassifier classifier) { + List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate"); + assertThat(findCategoryWithLabel(negativeResults, "negative").getScore()) + .isGreaterThan(findCategoryWithLabel(negativeResults, "positive").getScore()); + + List<Category> positiveResults = + classifier.classify("it's a charming and often affecting journey"); + assertThat(findCategoryWithLabel(positiveResults, "positive").getScore()) + .isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore()); + } +}
\ No newline at end of file diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite Binary files differnew file mode 100644 index 00000000..97a32da4 --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite |