aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Presubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com> 2022-08-08 20:44:00 +0000
committer: Presubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com> 2022-08-08 20:44:00 +0000
commit: 9bf1b3cc8d4fd76abea46c3daa93a34bbf592310 (patch)
tree: 845121243bce2106a5d816226999d6c87d6748bf
parent: 77decfd9b6e3f00ede4df76c83a81c05fc09d49e (diff)
parent: e4266a8951b0f53a80bc54392276c355583c0a53 (diff)
download: tflite-support-9bf1b3cc8d4fd76abea46c3daa93a34bbf592310.tar.gz
[automerge] Supports dynamic tensor inputs in BertNLClassifier. 2p: e4266a8951
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/tflite-support/+/19536078
Bug: 241507692
Change-Id: I9c9010f13f037d1b3c3fb7547c737a477ecc0313
-rw-r--r-- Android.bp 4
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc 113
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h 1
-rw-r--r-- tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java 2
4 files changed, 95 insertions, 25 deletions
diff --git a/Android.bp b/Android.bp
index 2b87f473..de4f3ddd 100644
--- a/Android.bp
+++ b/Android.bp
@@ -218,7 +218,7 @@ cc_library_shared {
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc",
- "tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc",
+ "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc",
"tensorflow_lite_support/cc/utils/jni_utils.cc",
],
shared_libs: ["liblog"],
@@ -477,4 +477,4 @@ genrule {
srcs: ["tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs"],
out: ["tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"],
defaults: ["tflite_support_fbgen"],
-}
\ No newline at end of file
+}
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
index d689c9e8..a246066b 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -55,12 +55,16 @@ namespace {
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
+constexpr int kIdsTensorIndex = 0;
+constexpr int kMaskTensorIndex = 1;
+constexpr int kSegmentIdsTensorIndex = 2;
constexpr char kScoreTensorName[] = "probability";
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
constexpr int kTokenizerProcessUnitIndex = 0;
} // namespace
+// TODO(b/241507692) Add a unit test for a model with dynamic tensors.
absl::Status BertNLClassifier::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
auto* input_tensor_metadatas =
@@ -78,39 +82,46 @@ absl::Status BertNLClassifier::Preprocess(
TokenizerResult input_tokenize_results;
input_tokenize_results = tokenizer_->Tokenize(processed_input);
- // 2 accounts for [CLS], [SEP]
- absl::Span<const std::string> query_tokens =
- absl::MakeSpan(input_tokenize_results.subwords.data(),
- input_tokenize_results.subwords.data() +
- std::min(static_cast<size_t>(kMaxSeqLen - 2),
- input_tokenize_results.subwords.size()));
-
- std::vector<std::string> tokens;
- tokens.reserve(2 + query_tokens.size());
- // Start of generating the features.
- tokens.push_back(kClassificationToken);
- // For query input.
- for (const auto& query_token : query_tokens) {
- tokens.push_back(query_token);
+ // Offset by 2 to account for [CLS] and [SEP]
+ int input_tokens_size =
+ static_cast<int>(input_tokenize_results.subwords.size()) + 2;
+ int input_tensor_length = input_tokens_size;
+ if (!input_tensors_are_dynamic_) {
+ input_tokens_size = std::min(kMaxSeqLen, input_tokens_size);
+ input_tensor_length = kMaxSeqLen;
+ } else {
+ GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex,
+ {1, input_tensor_length});
+ GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex,
+ {1, input_tensor_length});
+ GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
+ {1, input_tensor_length});
+ GetTfLiteEngine()->interpreter()->AllocateTensors();
}
- // For Separation.
- tokens.push_back(kSeparator);
- std::vector<int> input_ids(kMaxSeqLen, 0);
- std::vector<int> input_mask(kMaxSeqLen, 0);
+ std::vector<std::string> input_tokens;
+ input_tokens.reserve(input_tokens_size);
+ input_tokens.push_back(std::string(kClassificationToken));
+ for (int i = 0; i < input_tokens_size - 2; ++i) {
+ input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
+ }
+ input_tokens.push_back(std::string(kSeparator));
+
+ std::vector<int> input_ids(input_tensor_length, 0);
+ std::vector<int> input_mask(input_tensor_length, 0);
// Convert tokens back into ids and set mask
- for (int i = 0; i < tokens.size(); ++i) {
- tokenizer_->LookupId(tokens[i], &input_ids[i]);
+ for (int i = 0; i < input_tokens.size(); ++i) {
+ tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_mask[i] = 1;
}
- // |<-----------kMaxSeqLen---------->|
+ // |<--------input_tensor_length------->|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0
PopulateTensor(input_ids, ids_tensor);
PopulateTensor(input_mask, mask_tensor);
- PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor);
+ PopulateTensor(std::vector<int>(input_tensor_length, 0), segment_ids_tensor);
return absl::OkStatus();
}
@@ -189,6 +200,64 @@ absl::Status BertNLClassifier::InitializeFromMetadata() {
TrySetLabelFromMetadata(
GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex))
.IgnoreError();
+
+ auto* input_tensor_metadatas =
+ GetMetadataExtractor()->GetInputTensorMetadata();
+ const auto& input_tensors = GetInputTensors();
+ const auto& ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
+ kIdsTensorName);
+ const auto& mask_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
+ kMaskTensorName);
+ const auto& segment_ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
+ kSegmentIdsTensorName);
+ if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
+ segment_ids_tensor.dims->size != 2) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ absl::StrFormat(
+ "The three input tensors in Bert models are expected to have dim "
+ "2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
+ "(%d).",
+ ids_tensor.dims->size, mask_tensor.dims->size,
+ segment_ids_tensor.dims->size),
+ TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+ }
+ if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
+ segment_ids_tensor.dims->data[0] != 1) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ absl::StrFormat(
+ "The three input tensors in Bert models are expected to have same "
+ "batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
+ "segment_ids_tensor (%d).",
+ ids_tensor.dims->data[0], mask_tensor.dims->data[0],
+ segment_ids_tensor.dims->data[0]),
+ TfLiteSupportStatus::kInvalidInputTensorSizeError);
+ }
+ if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
+ ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ absl::StrFormat("The three input tensors in Bert models are "
+ "expected to have same length, but got ids_tensor "
+ "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
+ ids_tensor.dims->data[1], mask_tensor.dims->data[1],
+ segment_ids_tensor.dims->data[1]),
+ TfLiteSupportStatus::kInvalidInputTensorSizeError);
+ }
+ if (ids_tensor.dims_signature->data[1] == -1 &&
+ mask_tensor.dims_signature->data[1] == -1 &&
+ segment_ids_tensor.dims_signature->data[1] == -1) {
+ input_tensors_are_dynamic_ = true;
+ } else if (ids_tensor.dims_signature->data[1] == -1 ||
+ mask_tensor.dims_signature->data[1] == -1 ||
+ segment_ids_tensor.dims_signature->data[1] == -1) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ "Input tensors contain a mix of static and dynamic tensors",
+ TfLiteSupportStatus::kInvalidInputTensorSizeError);
+ }
+
return absl::OkStatus();
}
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
index 0c709ee0..7a4b0587 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
@@ -95,6 +95,7 @@ class BertNLClassifier : public NLClassifier {
absl::Status InitializeFromMetadata();
std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
+ bool input_tensors_are_dynamic_ = false;
};
} // namespace nlclassifier
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
index f0667d77..0a609fd2 100644
--- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -93,4 +93,4 @@ public class BertNLClassifierTest {
assertThat(findCategoryWithLabel(positiveResults, "positive").getScore())
.isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore());
}
-}
\ No newline at end of file
+}