aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Nissan <bennissan@google.com>2022-07-15 15:50:00 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2022-07-15 15:50:00 +0000
commit34ba670c765f4b43e068fcfb889d83d7f10170ec (patch)
treef667f047f21f66def01100d8fd36f0f95f707d54
parentf5a5d1ca8a4b1455a969ac19776341f0b3638870 (diff)
parent5d1e591a33054b75a5214c75be68cc14877b31d2 (diff)
downloadtflite-support-34ba670c765f4b43e068fcfb889d83d7f10170ec.tar.gz
Merge "Use custom op resolver for NL classifiers" into tm-mainline-prod
-rw-r--r--Android.bp2
-rw-r--r--tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc61
2 files changed, 62 insertions, 1 deletions
diff --git a/Android.bp b/Android.bp
index a0fb7f15..daed67eb 100644
--- a/Android.bp
+++ b/Android.bp
@@ -218,7 +218,7 @@ cc_library_shared {
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc",
- "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc",
+ "tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc",
"tensorflow_lite_support/cc/utils/jni_utils.cc",
],
shared_libs: ["liblog"],
diff --git a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
new file mode 100644
index 00000000..31d693a8
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
@@ -0,0 +1,61 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/kernels/builtin_op_kernels.h"
+#include "tensorflow/lite/op_resolver.h"
+
+namespace tflite {
+namespace task {
+
+// Create a minimal MutableOpResolver to provide only
+// the ops required by NLClassifier/BertNLClassifier.
+std::unique_ptr<MutableOpResolver> CreateOpResolver() {
+ MutableOpResolver resolver;
+ resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE,
+ ::tflite::ops::builtin::Register_DEQUANTIZE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_RESHAPE,
+ ::tflite::ops::builtin::Register_RESHAPE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_GATHER,
+ ::tflite::ops::builtin::Register_GATHER());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_STRIDED_SLICE,
+ ::tflite::ops::builtin::Register_STRIDED_SLICE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_PAD,
+ ::tflite::ops::builtin::Register_PAD());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
+ ::tflite::ops::builtin::Register_CONCATENATION());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
+ ::tflite::ops::builtin::Register_FULLY_CONNECTED());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_CAST,
+ ::tflite::ops::builtin::Register_CAST());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_MUL,
+ ::tflite::ops::builtin::Register_MUL());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
+ ::tflite::ops::builtin::Register_ADD());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_TRANSPOSE,
+ ::tflite::ops::builtin::Register_TRANSPOSE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SPLIT,
+ ::tflite::ops::builtin::Register_SPLIT());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_PACK,
+ ::tflite::ops::builtin::Register_PACK());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SOFTMAX,
+ ::tflite::ops::builtin::Register_SOFTMAX());
+ return absl::make_unique<MutableOpResolver>(resolver);
+}
+
+} // namespace task
+} // namespace tflite \ No newline at end of file