summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorandroid-build-team Robot <android-build-team-robot@google.com>2020-05-22 01:07:25 +0000
committerandroid-build-team Robot <android-build-team-robot@google.com>2020-05-22 01:07:25 +0000
commit6b99f36d5097e92382fc5d72982ff43ac4417dbe (patch)
tree1cc80d6e5e71d1414d921d6ccba03ba010852e14
parent196188dee0b1c40318b8b73d44c69567712736a3 (diff)
parent0824d7c6d6821941bde2d1b82efb7982ff7cc8a4 (diff)
downloadml-6b99f36d5097e92382fc5d72982ff43ac4417dbe.tar.gz
Snap for 6520394 from 0824d7c6d6821941bde2d1b82efb7982ff7cc8a4 to rvc-release
Change-Id: I066342232903e34667b03d501433d9823d452570
-rw-r--r--nn/common/Utils.cpp45
1 files changed, 23 insertions, 22 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 81e5cf1e1..fedc8cb30 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -1083,6 +1083,20 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
outExpectedTypes);
}
case ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM: {
+ const uint32_t kNumOutputs = 2;
+ const uint32_t kNumOutputsMerged = 1;
+ const uint32_t kNumOutputsWithState = 6;
+ const uint32_t kNumOutputsMergedWithState = 5;
+ if (inputCount != 61 ||
+ (outputCount != kNumOutputs && outputCount != kNumOutputsMerged &&
+ outputCount != kNumOutputsWithState &&
+ outputCount != kNumOutputsMergedWithState)) {
+ LOG(ERROR) << "Invalid number of input operands (" << inputCount
+ << ", expected 61) or output operands (" << outputCount
+ << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+
std::vector<OperandType> inExpectedTypes;
auto inputType = operands[inputIndexes[0]].type;
if (inputType != OperandType::TENSOR_FLOAT32 &&
@@ -1109,20 +1123,6 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
inExpectedTypes.push_back(inputType);
}
- const uint32_t kNumOutputs = 2;
- const uint32_t kNumOutputsMerged = 1;
- const uint32_t kNumOutputsWithState = 6;
- const uint32_t kNumOutputsMergedWithState = 5;
-
- if (inputCount != 61 ||
- (outputCount != kNumOutputs && outputCount != kNumOutputsMerged &&
- outputCount != kNumOutputsWithState &&
- outputCount != kNumOutputsMergedWithState)) {
- LOG(ERROR) << "Invalid number of input operands (" << inputCount
- << ", expected 61) or output operands (" << outputCount
- << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType);
- return ANEURALNETWORKS_BAD_DATA;
- }
HalVersion minSupportedHalVersion = HalVersion::V1_2;
if (outputCount == kNumOutputsWithState || outputCount == kNumOutputsMergedWithState) {
minSupportedHalVersion = HalVersion::V1_3;
@@ -1135,6 +1135,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
return status;
}
case ANEURALNETWORKS_LSTM: {
+ if ((inputCount != 23 && inputCount != 27) || outputCount != 4) {
+ LOG(ERROR) << "Invalid number of input operands (" << inputCount
+ << ", expected 23 or 27) or output operands (" << outputCount
+ << ", expected 4) for operation " << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
std::vector<OperandType> inExpectedTypes;
std::vector<OperandType> outExpectedTypes;
auto inputType = operands[inputIndexes[0]].type;
@@ -1160,18 +1166,13 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
}
outExpectedTypes = {inputType, inputType, inputType, inputType};
- if (inputCount == 23 && outputCount == 4) {
+ if (inputCount == 23) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- } else if (inputCount == 27 && outputCount == 4) {
+ } else {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
for (int i = 0; i < 4; ++i) {
inExpectedTypes.push_back(inputType);
}
- NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
- } else {
- LOG(ERROR) << "Invalid number of input operands (" << inputCount
- << ", expected 23 or 27) or output operands (" << outputCount
- << ", expected 4) for operation " << getOperationName(opType);
- return ANEURALNETWORKS_BAD_DATA;
}
return validateOperationOperandTypes(operands, inputCount, inputIndexes,
inExpectedTypes, outputCount, outputIndexes,