diff options
author | Presubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com> | 2022-08-08 20:44:00 +0000 |
---|---|---|
committer | Presubmit Automerger Backend <android-build-presubmit-automerger-backend@system.gserviceaccount.com> | 2022-08-08 20:44:00 +0000 |
commit | 9bf1b3cc8d4fd76abea46c3daa93a34bbf592310 (patch) | |
tree | 845121243bce2106a5d816226999d6c87d6748bf | |
parent | 77decfd9b6e3f00ede4df76c83a81c05fc09d49e (diff) | |
parent | e4266a8951b0f53a80bc54392276c355583c0a53 (diff) | |
download | tflite-support-9bf1b3cc8d4fd76abea46c3daa93a34bbf592310.tar.gz |
[automerge] Supports dynamic tensor inputs in BertNLClassifier. 2p: e4266a8951
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/tflite-support/+/19536078
Bug: 241507692
Change-Id: I9c9010f13f037d1b3c3fb7547c737a477ecc0313
4 files changed, 95 insertions, 25 deletions
@@ -218,7 +218,7 @@ cc_library_shared { "tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc", "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc", "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc", - "tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc", + "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc", "tensorflow_lite_support/cc/utils/jni_utils.cc", ], shared_libs: ["liblog"], @@ -477,4 +477,4 @@ genrule { srcs: ["tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs"], out: ["tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"], defaults: ["tflite_support_fbgen"], -}
\ No newline at end of file +} diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc index d689c9e8..a246066b 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc @@ -55,12 +55,16 @@ namespace { constexpr char kIdsTensorName[] = "ids"; constexpr char kMaskTensorName[] = "mask"; constexpr char kSegmentIdsTensorName[] = "segment_ids"; +constexpr int kIdsTensorIndex = 0; +constexpr int kMaskTensorIndex = 1; +constexpr int kSegmentIdsTensorIndex = 2; constexpr char kScoreTensorName[] = "probability"; constexpr char kClassificationToken[] = "[CLS]"; constexpr char kSeparator[] = "[SEP]"; constexpr int kTokenizerProcessUnitIndex = 0; } // namespace +// TODO(b/241507692) Add a unit test for a model with dynamic tensors. absl::Status BertNLClassifier::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) { auto* input_tensor_metadatas = @@ -78,39 +82,46 @@ absl::Status BertNLClassifier::Preprocess( TokenizerResult input_tokenize_results; input_tokenize_results = tokenizer_->Tokenize(processed_input); - // 2 accounts for [CLS], [SEP] - absl::Span<const std::string> query_tokens = - absl::MakeSpan(input_tokenize_results.subwords.data(), - input_tokenize_results.subwords.data() + - std::min(static_cast<size_t>(kMaxSeqLen - 2), - input_tokenize_results.subwords.size())); - - std::vector<std::string> tokens; - tokens.reserve(2 + query_tokens.size()); - // Start of generating the features. - tokens.push_back(kClassificationToken); - // For query input. - for (const auto& query_token : query_tokens) { - tokens.push_back(query_token); + // Offset by 2 to account for [CLS] and [SEP] + int input_tokens_size = + static_cast<int>(input_tokenize_results.subwords.size()) + 2; + int input_tensor_length = input_tokens_size; + if (!input_tensors_are_dynamic_) { + input_tokens_size = std::min(kMaxSeqLen, input_tokens_size); + input_tensor_length = kMaxSeqLen; + } else { + GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex, + {1, input_tensor_length}); + GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex, + {1, input_tensor_length}); + GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex, + {1, input_tensor_length}); + GetTfLiteEngine()->interpreter()->AllocateTensors(); } - // For Separation. - tokens.push_back(kSeparator); - std::vector<int> input_ids(kMaxSeqLen, 0); - std::vector<int> input_mask(kMaxSeqLen, 0); + std::vector<std::string> input_tokens; + input_tokens.reserve(input_tokens_size); + input_tokens.push_back(std::string(kClassificationToken)); + for (int i = 0; i < input_tokens_size - 2; ++i) { + input_tokens.push_back(std::move(input_tokenize_results.subwords[i])); + } + input_tokens.push_back(std::string(kSeparator)); + + std::vector<int> input_ids(input_tensor_length, 0); + std::vector<int> input_mask(input_tensor_length, 0); // Convert tokens back into ids and set mask - for (int i = 0; i < tokens.size(); ++i) { - tokenizer_->LookupId(tokens[i], &input_ids[i]); + for (int i = 0; i < input_tokens.size(); ++i) { + tokenizer_->LookupId(input_tokens[i], &input_ids[i]); input_mask[i] = 1; } - // |<-----------kMaxSeqLen---------->| + // |<--------input_tensor_length------->| // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 // input_masks 1 1 1... 1 1 0 0... 0 // segment_ids 0 0 0... 0 0 0 0... 0 PopulateTensor(input_ids, ids_tensor); PopulateTensor(input_mask, mask_tensor); - PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor); + PopulateTensor(std::vector<int>(input_tensor_length, 0), segment_ids_tensor); return absl::OkStatus(); } @@ -189,6 +200,64 @@ absl::Status BertNLClassifier::InitializeFromMetadata() { TrySetLabelFromMetadata( GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex)) .IgnoreError(); + + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + const auto& input_tensors = GetInputTensors(); + const auto& ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas, + kIdsTensorName); + const auto& mask_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas, + kMaskTensorName); + const auto& segment_ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas, + kSegmentIdsTensorName); + if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 || + segment_ids_tensor.dims->size != 2) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat( + "The three input tensors in Bert models are expected to have dim " + "2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor " + "(%d).", + ids_tensor.dims->size, mask_tensor.dims->size, + segment_ids_tensor.dims->size), + TfLiteSupportStatus::kInvalidInputTensorDimensionsError); + } + if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 || + segment_ids_tensor.dims->data[0] != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat( + "The three input tensors in Bert models are expected to have same " + "batch size 1, but got ids_tensor (%d), mask_tensor (%d), " + "segment_ids_tensor (%d).", + ids_tensor.dims->data[0], mask_tensor.dims->data[0], + segment_ids_tensor.dims->data[0]), + TfLiteSupportStatus::kInvalidInputTensorSizeError); + } + if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] || + ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("The three input tensors in Bert models are " + "expected to have same length, but got ids_tensor " + "(%d), mask_tensor (%d), segment_ids_tensor (%d).", + ids_tensor.dims->data[1], mask_tensor.dims->data[1], + segment_ids_tensor.dims->data[1]), + TfLiteSupportStatus::kInvalidInputTensorSizeError); + } + if (ids_tensor.dims_signature->data[1] == -1 && + mask_tensor.dims_signature->data[1] == -1 && + segment_ids_tensor.dims_signature->data[1] == -1) { + input_tensors_are_dynamic_ = true; + } else if (ids_tensor.dims_signature->data[1] == -1 || + mask_tensor.dims_signature->data[1] == -1 || + segment_ids_tensor.dims_signature->data[1] == -1) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Input tensors contain a mix of static and dynamic tensors", + TfLiteSupportStatus::kInvalidInputTensorSizeError); + } + return absl::OkStatus(); } diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h index 0c709ee0..7a4b0587 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h @@ -95,6 +95,7 @@ class BertNLClassifier : public NLClassifier { absl::Status InitializeFromMetadata(); std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; + bool input_tensors_are_dynamic_ = false; }; } // namespace nlclassifier diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java index f0667d77..0a609fd2 100644 --- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java @@ -93,4 +93,4 @@ public class BertNLClassifierTest { assertThat(findCategoryWithLabel(positiveResults, "positive").getScore()) .isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore()); } -}
\ No newline at end of file +} |