aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPresubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com>2022-06-02 18:29:08 +0000
committerPresubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com>2022-06-02 18:29:08 +0000
commitbf4efa0424141f9e77edc5f103aa48f2d62d6a36 (patch)
tree0a5b70a22853376cf2a1f84612cd8d810ac94357
parentfb6bf1937de42d081d0736baa5380597d262da1d (diff)
parent3119c9ac2289a461a8c5e40b576dec683a30f0f1 (diff)
downloadtflite-support-bf4efa0424141f9e77edc5f103aa48f2d62d6a36.tar.gz
[automerge] Port BertNLClassifierTest into AOSP 2p: 3119c9ac22
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/tflite-support/+/18716813 Bug: 232807230 Change-Id: I68c7792a856814dde3a32e96ea4e10c029b12fb2
-rw-r--r--Android.bp51
-rw-r--r--tensorflow_lite_support/java/AndroidManifest.xml4
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java50
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java76
-rw-r--r--tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflitebin0 -> 25707538 bytes
5 files changed, 181 insertions, 0 deletions
diff --git a/Android.bp b/Android.bp
index 96debe5b..a0fb7f15 100644
--- a/Android.bp
+++ b/Android.bp
@@ -373,6 +373,57 @@ cc_test {
],
}
+android_test {
+ name: "TfliteSupportClassifierTests",
+ srcs: [
+ "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java",
+ ],
+ asset_dirs: [
+ "tensorflow_lite_support/java/src/javatests/testdata/task/text",
+ ],
+ defaults: ["modules-utils-testable-device-config-defaults"],
+ manifest: "tensorflow_lite_support/java/AndroidManifest.xml",
+ sdk_version: "module_current",
+ min_sdk_version: "30",
+ static_libs: [
+ "androidx.test.core",
+ "tensorflowlite_java",
+ "truth-prebuilt",
+ "tflite_support_classifiers_java",
+ "tflite_support_test_utils_java",
+ ],
+ libs: [
+ "android.test.base",
+ "android.test.mock.stubs",
+ ],
+ test_suites: [
+ "general-tests",
+ ],
+ jni_libs: [
+ "libtflite_support_classifiers_native",
+ ],
+ aaptflags: [
+ // Avoid compression on tflite files as the Interpreter
+ // can not load compressed flat buffer formats.
+ // (*appt compresses all assets into the apk by default)
+ // See https://elinux.org/Android_aapt for more detail.
+ "-0 .tflite",
+ ],
+}
+
+java_library_static {
+ name: "tflite_support_test_utils_java",
+ sdk_version: "module_current",
+ min_sdk_version: "30",
+ srcs: [
+ "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java",
+ ],
+ static_libs: [
+ "apache-commons-compress",
+ "guava",
+ ],
+}
+
cc_library_static {
name: "tflite_support_task_core_proto",
proto: {
diff --git a/tensorflow_lite_support/java/AndroidManifest.xml b/tensorflow_lite_support/java/AndroidManifest.xml
index 14909296..c36eb383 100644
--- a/tensorflow_lite_support/java/AndroidManifest.xml
+++ b/tensorflow_lite_support/java/AndroidManifest.xml
@@ -2,4 +2,8 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.support">
<uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="org.tensorflow.lite.support" >
+ </instrumentation>
</manifest>
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java
new file mode 100644
index 00000000..4c6f369d
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java
@@ -0,0 +1,50 @@
+package org.tensorflow.lite.task.core;
+
+import android.content.Context;
+import android.content.res.AssetManager;
+
+import com.google.common.io.ByteStreams;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/** Helper class for the Java test in Task Libary. */
+public final class TestUtils {
+
+ /**
+ * Loads the file and create a {@link File} object by reading a file from the asset directory.
+ * Simulates downloading or reading a file that's not precompiled with the app.
+ *
+ * @return a {@link File} object for the model.
+ */
+ public static File loadFile(Context context, String fileName) {
+ File target = new File(context.getFilesDir(), fileName);
+ try (InputStream is = context.getAssets().open(fileName);
+ FileOutputStream os = new FileOutputStream(target)) {
+ ByteStreams.copy(is, os);
+ } catch (IOException e) {
+ throw new AssertionError("Failed to load model file at " + fileName, e);
+ }
+ return target;
+ }
+
+ /**
+ * Reads a file into a direct {@link ByteBuffer} object from the asset directory.
+ *
+ * @return a {@link ByteBuffer} object for the file.
+ */
+ public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName)
+ throws IOException {
+ AssetManager assetManager = context.getAssets();
+ InputStream inputStream = assetManager.open(fileName);
+ byte[] bytes = ByteStreams.toByteArray(inputStream);
+
+ ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
+ buffer.put(bytes);
+ return buffer;
+ }
+} \ No newline at end of file
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
new file mode 100644
index 00000000..24a91ac4
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -0,0 +1,76 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.nlclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.core.app.ApplicationProvider;
+import java.io.IOException;
+import java.util.List;
+import org.junit.Test;
+import org.tensorflow.lite.support.label.Category;
+import org.tensorflow.lite.task.core.TestUtils;
+
+/** Test for {@link BertNLClassifier}. */
+public class BertNLClassifierTest {
+ private static final String MODEL_FILE = "bert_nl_classifier.tflite";
+
+ Category findCategoryWithLabel(List<Category> list, String label) {
+ return list.stream()
+ .filter(category -> label.equals(category.getLabel()))
+ .findAny()
+ .orElse(null);
+ }
+
+ @Test
+ public void createFromPath_verifyResults() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
+ }
+
+ @Test
+ public void createFromFile_verifyResults() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromFile(
+ TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)));
+ }
+
+ @Test
+ public void classify_succeedsWithModelFile() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE));
+ }
+
+ @Test
+ public void classify_succeedsWithModelBuffer() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromBuffer(
+ TestUtils.loadToDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE)));
+ }
+
+ private void verifyResults(BertNLClassifier classifier) {
+ List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate");
+ assertThat(findCategoryWithLabel(negativeResults, "negative").getScore())
+ .isGreaterThan(findCategoryWithLabel(negativeResults, "positive").getScore());
+
+ List<Category> positiveResults =
+ classifier.classify("it's a charming and often affecting journey");
+ assertThat(findCategoryWithLabel(positiveResults, "positive").getScore())
+ .isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore());
+ }
+} \ No newline at end of file
diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite
new file mode 100644
index 00000000..97a32da4
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite
Binary files differ