summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXusong Wang <xusongw@google.com>2020-05-26 17:43:13 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2020-05-26 17:43:13 +0000
commit66e5923200afc965bf19b880737e9180e9f5c909 (patch)
tree744bfbff7a4b6062d8011b0721f4577e4c5a4a73
parent0824d7c6d6821941bde2d1b82efb7982ff7cc8a4 (diff)
parentf0af901e251b46938ceca80658b5cefc67fc7b6d (diff)
downloadml-66e5923200afc965bf19b880737e9180e9f5c909.tar.gz
Merge changes Ib3b191cc,I9afea607 into rvc-dev
* changes: Fix FULLY_CONNECTED issue with unknown num_units. Fix CAST issue with outputs of unknown rank.
-rw-r--r--nn/common/operations/Cast.cpp8
-rw-r--r--nn/common/operations/FullyConnected.cpp5
2 files changed, 7 insertions, 6 deletions
diff --git a/nn/common/operations/Cast.cpp b/nn/common/operations/Cast.cpp
index f8ca4022e..77e35afb0 100644
--- a/nn/common/operations/Cast.cpp
+++ b/nn/common/operations/Cast.cpp
@@ -17,12 +17,13 @@
#define LOG_TAG "Operations"
#include "Cast.h"
+
+#include <algorithm>
+
#include "HalInterfaces.h"
#include "Operations.h"
#include "Tracing.h"
-#include <algorithm>
-
namespace android {
namespace nn {
namespace cast {
@@ -67,9 +68,6 @@ bool copyToTensor(const FromT* inputData, int numElements, uint8_t* outputData,
} // namespace
bool prepare(const Shape& input, Shape* output) {
- if (input.dimensions.size() != output->dimensions.size()) {
- return false;
- }
output->dimensions = input.dimensions;
return true;
}
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp
index 71808c0b7..9bdd0bab2 100644
--- a/nn/common/operations/FullyConnected.cpp
+++ b/nn/common/operations/FullyConnected.cpp
@@ -200,11 +200,14 @@ bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias,
uint32_t input_n_elements = getNumberOfElements(input);
uint32_t num_units = getSizeOfDimension(weights, 0);
uint32_t input_size = getSizeOfDimension(weights, 1);
+ uint32_t bias_len = getSizeOfDimension(bias, 0);
uint32_t batch_size = input_size == 0 ? 0 : input_n_elements / input_size;
if (batch_size != 0) {
NN_RET_CHECK_EQ(input_size * batch_size, input_n_elements);
}
- NN_RET_CHECK_EQ(getSizeOfDimension(bias, 0), num_units);
+ if (num_units != 0 && bias_len != 0) {
+ NN_RET_CHECK_EQ(bias_len, num_units);
+ }
if (output != nullptr) {
// Only batch_size can be 0.
NN_RET_CHECK_GT(num_units, 0);