diff options
author | Ben Nissan <bennissan@google.com> | 2022-07-15 15:50:00 +0000 |
---|---|---|
committer | Android (Google) Code Review <android-gerrit@google.com> | 2022-07-15 15:50:00 +0000 |
commit | 34ba670c765f4b43e068fcfb889d83d7f10170ec (patch) | |
tree | f667f047f21f66def01100d8fd36f0f95f707d54 | |
parent | f5a5d1ca8a4b1455a969ac19776341f0b3638870 (diff) | |
parent | 5d1e591a33054b75a5214c75be68cc14877b31d2 (diff) | |
download | tflite-support-34ba670c765f4b43e068fcfb889d83d7f10170ec.tar.gz |
Merge "Use custom op resolver for NL classifiers" into tm-mainline-prod
-rw-r--r-- | Android.bp | 2 | ||||
-rw-r--r-- | tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc | 61 |
2 files changed, 62 insertions, 1 deletions
@@ -218,7 +218,7 @@ cc_library_shared { "tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc", "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc", "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc", - "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc", + "tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc", "tensorflow_lite_support/cc/utils/jni_utils.cc", ], shared_libs: ["liblog"], diff --git a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc new file mode 100644 index 00000000..31d693a8 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> + +#include "absl/memory/memory.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/op_resolver.h" + +namespace tflite { +namespace task { + +// Create a minimal MutableOpResolver to provide only +// the ops required by NLClassifier/BertNLClassifier. +std::unique_ptr<MutableOpResolver> CreateOpResolver() { + MutableOpResolver resolver; + resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE, + ::tflite::ops::builtin::Register_DEQUANTIZE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_GATHER, + ::tflite::ops::builtin::Register_GATHER()); + resolver.AddBuiltin(::tflite::BuiltinOperator_STRIDED_SLICE, + ::tflite::ops::builtin::Register_STRIDED_SLICE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_PAD, + ::tflite::ops::builtin::Register_PAD()); + resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, + ::tflite::ops::builtin::Register_CONCATENATION()); + resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, + ::tflite::ops::builtin::Register_FULLY_CONNECTED()); + resolver.AddBuiltin(::tflite::BuiltinOperator_CAST, + ::tflite::ops::builtin::Register_CAST()); + resolver.AddBuiltin(::tflite::BuiltinOperator_MUL, + ::tflite::ops::builtin::Register_MUL()); + resolver.AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + resolver.AddBuiltin(::tflite::BuiltinOperator_TRANSPOSE, + ::tflite::ops::builtin::Register_TRANSPOSE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SPLIT, + ::tflite::ops::builtin::Register_SPLIT()); + resolver.AddBuiltin(::tflite::BuiltinOperator_PACK, + ::tflite::ops::builtin::Register_PACK()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + return absl::make_unique<MutableOpResolver>(resolver); +} + +} // namespace task +} // namespace tflite
\ No newline at end of file |