summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTreehugger Robot <treehugger-gerrit@google.com>2021-02-02 03:34:04 +0000
committerGerrit Code Review <noreply-gerritcodereview@google.com>2021-02-02 03:34:04 +0000
commitb3b158e575aaaac56a0ebe025fc16946ca86d70a (patch)
tree9fc858e3c387affe6bd23a12cc549408eeb0ee5d
parenta90ed9c1702e90692930d3270f0cd71884452090 (diff)
parentf4767499dca3879b399105756a247632fe46ff7a (diff)
downloadml-b3b158e575aaaac56a0ebe025fc16946ca86d70a.tar.gz
Merge "Reject extension operations in sample drivers" into android11-gsi
-rw-r--r--nn/driver/sample/SampleDriverFloatFast.cpp2
-rw-r--r--nn/driver/sample/SampleDriverFloatSlow.cpp2
-rw-r--r--nn/driver/sample/SampleDriverFull.cpp4
-rw-r--r--nn/driver/sample/SampleDriverQuant.cpp2
-rw-r--r--nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp58
5 files changed, 46 insertions, 22 deletions
diff --git a/nn/driver/sample/SampleDriverFloatFast.cpp b/nn/driver/sample/SampleDriverFloatFast.cpp
index bb4b815b0..5d2cd1344 100644
--- a/nn/driver/sample/SampleDriverFloatFast.cpp
+++ b/nn/driver/sample/SampleDriverFloatFast.cpp
@@ -67,7 +67,7 @@ std::vector<bool> SampleDriverFloatFast::getSupportedOperationsImpl(
std::vector<bool> supported(count);
for (size_t i = 0; i < count; i++) {
const Operation& operation = model.main.operations[i];
- if (operation.inputs.size() > 0) {
+ if (!isExtensionOperationType(operation.type) && operation.inputs.size() > 0) {
const Operand& firstOperand = model.main.operands[operation.inputs[0]];
supported[i] = firstOperand.type == OperandType::TENSOR_FLOAT32;
}
diff --git a/nn/driver/sample/SampleDriverFloatSlow.cpp b/nn/driver/sample/SampleDriverFloatSlow.cpp
index 12e972cd5..1e6f0cb0d 100644
--- a/nn/driver/sample/SampleDriverFloatSlow.cpp
+++ b/nn/driver/sample/SampleDriverFloatSlow.cpp
@@ -67,7 +67,7 @@ std::vector<bool> SampleDriverFloatSlow::getSupportedOperationsImpl(
std::vector<bool> supported(count);
for (size_t i = 0; i < count; i++) {
const Operation& operation = model.main.operations[i];
- if (operation.inputs.size() > 0) {
+ if (!isExtensionOperationType(operation.type) && operation.inputs.size() > 0) {
const Operand& firstOperand = model.main.operands[operation.inputs[0]];
supported[i] = firstOperand.type == OperandType::TENSOR_FLOAT32;
}
diff --git a/nn/driver/sample/SampleDriverFull.cpp b/nn/driver/sample/SampleDriverFull.cpp
index 563551712..e0f15eaa3 100644
--- a/nn/driver/sample/SampleDriverFull.cpp
+++ b/nn/driver/sample/SampleDriverFull.cpp
@@ -48,6 +48,10 @@ Return<void> SampleDriverFull::getSupportedOperations_1_3(const V1_3::Model& mod
if (validateModel(model)) {
const size_t count = model.main.operations.size();
std::vector<bool> supported(count, true);
+ for (size_t i = 0; i < count; i++) {
+ const Operation& operation = model.main.operations[i];
+ supported[i] = !isExtensionOperationType(operation.type);
+ }
cb(ErrorStatus::NONE, supported);
} else {
std::vector<bool> supported;
diff --git a/nn/driver/sample/SampleDriverQuant.cpp b/nn/driver/sample/SampleDriverQuant.cpp
index 39d02a643..91eb6e268 100644
--- a/nn/driver/sample/SampleDriverQuant.cpp
+++ b/nn/driver/sample/SampleDriverQuant.cpp
@@ -67,7 +67,7 @@ std::vector<bool> SampleDriverQuant::getSupportedOperationsImpl(const V1_3::Mode
std::vector<bool> supported(count);
for (size_t i = 0; i < count; i++) {
const Operation& operation = model.main.operations[i];
- if (operation.inputs.size() > 0) {
+ if (!isExtensionOperationType(operation.type) && operation.inputs.size() > 0) {
const Operand& firstOperand = model.main.operands[operation.inputs[0]];
supported[i] = isQuantized(firstOperand.type);
if (operation.type == OperationType::SELECT) {
diff --git a/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp b/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp
index cdafa344f..faeeda3fd 100644
--- a/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp
+++ b/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp
@@ -14,6 +14,12 @@
* limitations under the License.
*/
+#include <gtest/gtest.h>
+
+#include <vector>
+
+#include "FibonacciDriver.h"
+#include "FibonacciExtension.h"
#include "HalInterfaces.h"
#include "Manager.h"
#include "NeuralNetworks.h"
@@ -24,13 +30,6 @@
#include "Utils.h"
#include "ValidateHal.h"
-#include <gtest/gtest.h>
-
-#include "FibonacciDriver.h"
-#include "FibonacciExtension.h"
-
-#include <vector>
-
namespace android {
namespace nn {
namespace {
@@ -58,27 +57,26 @@ class FibonacciExtensionTest : public ::testing::Test {
uint32_t numDevices = 0;
ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
- ANeuralNetworksDevice* fibonacciDevice = nullptr;
- ANeuralNetworksDevice* cpuDevice = nullptr;
for (uint32_t i = 0; i < numDevices; i++) {
ANeuralNetworksDevice* device = nullptr;
EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
+ mAllDevices.push_back(device);
bool supportsFibonacciExtension;
ASSERT_EQ(
ANeuralNetworksDevice_getExtensionSupport(
device, EXAMPLE_FIBONACCI_EXTENSION_NAME, &supportsFibonacciExtension),
ANEURALNETWORKS_NO_ERROR);
if (supportsFibonacciExtension) {
- ASSERT_EQ(fibonacciDevice, nullptr) << "Found multiple Fibonacci drivers";
- fibonacciDevice = device;
+ ASSERT_EQ(mFibonacciDevice, nullptr) << "Found multiple Fibonacci drivers";
+ mFibonacciDevice = device;
} else if (DeviceManager::get()->forTest_isCpuDevice(device)) {
- ASSERT_EQ(cpuDevice, nullptr) << "Found multiple CPU drivers";
- cpuDevice = device;
+ ASSERT_EQ(mCpuDevice, nullptr) << "Found multiple CPU drivers";
+ mCpuDevice = device;
}
}
- ASSERT_NE(fibonacciDevice, nullptr) << "Expecting Fibonacci driver to be available";
- ASSERT_NE(cpuDevice, nullptr) << "Expecting CPU driver to be available";
- mDevices = {fibonacciDevice, cpuDevice};
+ ASSERT_NE(mFibonacciDevice, nullptr) << "Expecting Fibonacci driver to be available";
+ ASSERT_NE(mCpuDevice, nullptr) << "Expecting CPU driver to be available";
+ mDevices = {mFibonacciDevice, mCpuDevice};
}
virtual void TearDown() {
@@ -92,12 +90,13 @@ class FibonacciExtensionTest : public ::testing::Test {
TypeManager::get()->forTest_reset();
}
- void checkSupportedOperations(const std::vector<bool>& expected) {
+ void checkSupportedOperations(const std::vector<bool>& expected,
+ const std::vector<ANeuralNetworksDevice*> devices) {
const uint32_t kMaxNumberOperations = 256;
EXPECT_LE(expected.size(), kMaxNumberOperations);
bool supported[kMaxNumberOperations] = {false};
EXPECT_EQ(ANeuralNetworksModel_getSupportedOperationsForDevices(
- mModel.getHandle(), mDevices.data(), mDevices.size(), supported),
+ mModel.getHandle(), devices.data(), devices.size(), supported),
ANEURALNETWORKS_NO_ERROR);
for (size_t i = 0; i < expected.size(); ++i) {
SCOPED_TRACE(::testing::Message() << "i = " << i);
@@ -105,6 +104,10 @@ class FibonacciExtensionTest : public ::testing::Test {
}
}
+ void checkSupportedOperations(const std::vector<bool>& expected) {
+ checkSupportedOperations(expected, mDevices);
+ }
+
void prepareForExecution() {
ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), mDevices.data(),
mDevices.size(), &mCompilation),
@@ -114,7 +117,10 @@ class FibonacciExtensionTest : public ::testing::Test {
ANEURALNETWORKS_NO_ERROR);
}
- std::vector<ANeuralNetworksDevice*> mDevices;
+ ANeuralNetworksDevice* mFibonacciDevice = nullptr;
+ ANeuralNetworksDevice* mCpuDevice = nullptr;
+ std::vector<ANeuralNetworksDevice*> mDevices; // Fibonacci and CPU devices.
+ std::vector<ANeuralNetworksDevice*> mAllDevices;
ANeuralNetworksExecution* mExecution = nullptr;
ANeuralNetworksCompilation* mCompilation = nullptr;
ExtensionModel mModel;
@@ -334,6 +340,20 @@ TEST_F(FibonacciExtensionTest, InvalidOperation) {
ASSERT_EQ(ANeuralNetworksCompilation_finish(mCompilation), ANEURALNETWORKS_BAD_DATA);
}
+TEST_F(FibonacciExtensionTest, GetSupportedOperations) {
+ ExtensionOperandType inputType(Type::TENSOR_FLOAT32, {1});
+ ExtensionOperandType outputType(Type::TENSOR_FLOAT32, {1});
+ createModel(&mModel, inputType, outputType, /*addNopOperations=*/false);
+
+ for (ANeuralNetworksDevice* device : mAllDevices) {
+ const char* name = nullptr;
+ ASSERT_EQ(ANeuralNetworksDevice_getName(device, &name), ANEURALNETWORKS_NO_ERROR);
+ SCOPED_TRACE(::testing::Message() << "device = " << name);
+ // Only Fibonacci device should support Fibonacci operation.
+ checkSupportedOperations({device == mFibonacciDevice}, {device});
+ }
+}
+
} // namespace
} // namespace nn
} // namespace android