diff options
author | Xusong Wang <xusongw@google.com> | 2020-05-15 10:48:33 -0700 |
---|---|---|
committer | Xusong Wang <xusongw@google.com> | 2020-05-21 10:29:45 -0700 |
commit | f0af901e251b46938ceca80658b5cefc67fc7b6d (patch) | |
tree | 41dbfa4a97391c1e8be62f0d53b5cbf830b863be | |
parent | 28e0c9d9d39528d0d1c9ed18b3b9d861c7ec06fb (diff) | |
download | ml-f0af901e251b46938ceca80658b5cefc67fc7b6d.tar.gz |
Fix FULLY_CONNECTED issue with unknown num_units.
Fixes: 156748888
Test: NNT_static
Test: 1.3 VTS with ag/11509996
Change-Id: Ib3b191ccefdfd0d03f8c69772976ef0f2421a9d7
-rw-r--r-- | nn/common/operations/FullyConnected.cpp | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp index 71808c0b7..9bdd0bab2 100644 --- a/nn/common/operations/FullyConnected.cpp +++ b/nn/common/operations/FullyConnected.cpp @@ -200,11 +200,14 @@ bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias, uint32_t input_n_elements = getNumberOfElements(input); uint32_t num_units = getSizeOfDimension(weights, 0); uint32_t input_size = getSizeOfDimension(weights, 1); + uint32_t bias_len = getSizeOfDimension(bias, 0); uint32_t batch_size = input_size == 0 ? 0 : input_n_elements / input_size; if (batch_size != 0) { NN_RET_CHECK_EQ(input_size * batch_size, input_n_elements); } - NN_RET_CHECK_EQ(getSizeOfDimension(bias, 0), num_units); + if (num_units != 0 && bias_len != 0) { + NN_RET_CHECK_EQ(bias_len, num_units); + } if (output != nullptr) { // Only batch_size can be 0. NN_RET_CHECK_GT(num_units, 0); |