aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-02-26 22:33:43 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-02-26 22:33:43 +0000
commitd382591689a4d29498f2a4f9462a67e1f4a2c709 (patch)
tree97cf973d8b5114e65bb9444dee1207cce6b36cd6
parent782e0b69af77bad7635ea84d745efb72cc3b64c6 (diff)
parent55d84e75c29f45f9a277c01be46f19edef4e76c9 (diff)
downloadtflite-support-android13-mainline-mediaprovider-release.tar.gz
Snap for 9656615 from 55d84e75c29f45f9a277c01be46f19edef4e76c9 to mainline-mediaprovider-releaseaml_mpr_331812020aml_mpr_331711020android13-mainline-mediaprovider-release
Change-Id: If2606ed32f877d48a66dc9997b65bf2337f3756f
-rw-r--r--Android.bp2
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java1
-rw-r--r--tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc46
3 files changed, 37 insertions, 12 deletions
diff --git a/Android.bp b/Android.bp
index 34cc018b..f3d5bb4e 100644
--- a/Android.bp
+++ b/Android.bp
@@ -225,7 +225,6 @@ cc_library_shared {
"tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc",
"tensorflow_lite_support/cc/utils/jni_utils.cc",
],
- // TODO(b/247088924): Use linker_scripts here.
version_script: "tensorflow_lite_support/java/tflite_version_script.lds",
shared_libs: ["liblog"],
static_libs: [
@@ -246,7 +245,6 @@ cc_library_shared {
"tensorflow_headers",
"flatbuffer_headers",
"jni_headers",
- "liblog_headers",
"libtextclassifier_flatbuffer_headers",
],
generated_headers: [
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
index 8c71f705..efaa9d99 100644
--- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -124,6 +124,5 @@ public class BertNLClassifierTest {
private void verifyDynamicInputResults(BertNLClassifier classifier) {
List<Category> topics = classifier.classify("FooBarBaz");
assertThat(topics.size()).isEqualTo(446);
- // TODO(ag/19888344): Add a test for a long text input.
}
}
diff --git a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
index 31d693a8..32d1054d 100644
--- a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
+++ b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include <memory>
-#include "absl/memory/memory.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/op_resolver.h"
@@ -23,21 +22,15 @@ namespace tflite {
namespace task {
// Create a minimal MutableOpResolver to provide only
-// the ops required by NLClassifier/BertNLClassifier.
+// the ops required by the bert_nl_classifier and rb_model for BertNLClassifier.
std::unique_ptr<MutableOpResolver> CreateOpResolver() {
MutableOpResolver resolver;
- resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE,
- ::tflite::ops::builtin::Register_DEQUANTIZE());
resolver.AddBuiltin(::tflite::BuiltinOperator_RESHAPE,
::tflite::ops::builtin::Register_RESHAPE());
resolver.AddBuiltin(::tflite::BuiltinOperator_GATHER,
::tflite::ops::builtin::Register_GATHER());
resolver.AddBuiltin(::tflite::BuiltinOperator_STRIDED_SLICE,
::tflite::ops::builtin::Register_STRIDED_SLICE());
- resolver.AddBuiltin(::tflite::BuiltinOperator_PAD,
- ::tflite::ops::builtin::Register_PAD());
- resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
- ::tflite::ops::builtin::Register_CONCATENATION());
resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
::tflite::ops::builtin::Register_FULLY_CONNECTED());
resolver.AddBuiltin(::tflite::BuiltinOperator_CAST,
@@ -54,7 +47,42 @@ std::unique_ptr<MutableOpResolver> CreateOpResolver() {
::tflite::ops::builtin::Register_PACK());
resolver.AddBuiltin(::tflite::BuiltinOperator_SOFTMAX,
::tflite::ops::builtin::Register_SOFTMAX());
- return absl::make_unique<MutableOpResolver>(resolver);
+ resolver.AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
+ ::tflite::ops::builtin::Register_EXPAND_DIMS());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SHAPE,
+ ::tflite::ops::builtin::Register_SHAPE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_FILL,
+ ::tflite::ops::builtin::Register_FILL());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SUB,
+ ::tflite::ops::builtin::Register_SUB());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_MEAN,
+ ::tflite::ops::builtin::Register_MEAN());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
+ ::tflite::ops::builtin::Register_SQUARED_DIFFERENCE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_RSQRT,
+ ::tflite::ops::builtin::Register_RSQRT());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_BATCH_MATMUL,
+ ::tflite::ops::builtin::Register_BATCH_MATMUL());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_GELU,
+ ::tflite::ops::builtin::Register_GELU());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_TANH,
+ ::tflite::ops::builtin::Register_TANH());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_LOGISTIC,
+ ::tflite::ops::builtin::Register_LOGISTIC());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SLICE,
+ ::tflite::ops::builtin::Register_SLICE());
+ // Needed for the test bert_nl_classifier model.
+ resolver.AddBuiltin(::tflite::BuiltinOperator_PAD,
+ ::tflite::ops::builtin::Register_PAD());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
+ ::tflite::ops::builtin::Register_CONCATENATION());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
+ ::tflite::ops::builtin::Register_FULLY_CONNECTED(),
+ /*version=*/9);
+ resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE,
+ ::tflite::ops::builtin::Register_DEQUANTIZE(),
+ /*version=*/2);
+ return std::make_unique<MutableOpResolver>(resolver);
}
} // namespace task