diff options
author | Treehugger Robot <treehugger-gerrit@google.com> | 2021-02-02 03:34:04 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2021-02-02 03:34:04 +0000 |
commit | b3b158e575aaaac56a0ebe025fc16946ca86d70a (patch) | |
tree | 9fc858e3c387affe6bd23a12cc549408eeb0ee5d | |
parent | a90ed9c1702e90692930d3270f0cd71884452090 (diff) | |
parent | f4767499dca3879b399105756a247632fe46ff7a (diff) | |
download | ml-b3b158e575aaaac56a0ebe025fc16946ca86d70a.tar.gz |
Merge "Reject extension operations in sample drivers" into android11-gsi
-rw-r--r-- | nn/driver/sample/SampleDriverFloatFast.cpp | 2 | ||||
-rw-r--r-- | nn/driver/sample/SampleDriverFloatSlow.cpp | 2 | ||||
-rw-r--r-- | nn/driver/sample/SampleDriverFull.cpp | 4 | ||||
-rw-r--r-- | nn/driver/sample/SampleDriverQuant.cpp | 2 | ||||
-rw-r--r-- | nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp | 58 |
5 files changed, 46 insertions, 22 deletions
diff --git a/nn/driver/sample/SampleDriverFloatFast.cpp b/nn/driver/sample/SampleDriverFloatFast.cpp index bb4b815b0..5d2cd1344 100644 --- a/nn/driver/sample/SampleDriverFloatFast.cpp +++ b/nn/driver/sample/SampleDriverFloatFast.cpp @@ -67,7 +67,7 @@ std::vector<bool> SampleDriverFloatFast::getSupportedOperationsImpl( std::vector<bool> supported(count); for (size_t i = 0; i < count; i++) { const Operation& operation = model.main.operations[i]; - if (operation.inputs.size() > 0) { + if (!isExtensionOperationType(operation.type) && operation.inputs.size() > 0) { const Operand& firstOperand = model.main.operands[operation.inputs[0]]; supported[i] = firstOperand.type == OperandType::TENSOR_FLOAT32; } diff --git a/nn/driver/sample/SampleDriverFloatSlow.cpp b/nn/driver/sample/SampleDriverFloatSlow.cpp index 12e972cd5..1e6f0cb0d 100644 --- a/nn/driver/sample/SampleDriverFloatSlow.cpp +++ b/nn/driver/sample/SampleDriverFloatSlow.cpp @@ -67,7 +67,7 @@ std::vector<bool> SampleDriverFloatSlow::getSupportedOperationsImpl( std::vector<bool> supported(count); for (size_t i = 0; i < count; i++) { const Operation& operation = model.main.operations[i]; - if (operation.inputs.size() > 0) { + if (!isExtensionOperationType(operation.type) && operation.inputs.size() > 0) { const Operand& firstOperand = model.main.operands[operation.inputs[0]]; supported[i] = firstOperand.type == OperandType::TENSOR_FLOAT32; } diff --git a/nn/driver/sample/SampleDriverFull.cpp b/nn/driver/sample/SampleDriverFull.cpp index 563551712..e0f15eaa3 100644 --- a/nn/driver/sample/SampleDriverFull.cpp +++ b/nn/driver/sample/SampleDriverFull.cpp @@ -48,6 +48,10 @@ Return<void> SampleDriverFull::getSupportedOperations_1_3(const V1_3::Model& mod if (validateModel(model)) { const size_t count = model.main.operations.size(); std::vector<bool> supported(count, true); + for (size_t i = 0; i < count; i++) { + const Operation& operation = model.main.operations[i]; + supported[i] = !isExtensionOperationType(operation.type); + } cb(ErrorStatus::NONE, supported); } else { std::vector<bool> supported; diff --git a/nn/driver/sample/SampleDriverQuant.cpp b/nn/driver/sample/SampleDriverQuant.cpp index 39d02a643..91eb6e268 100644 --- a/nn/driver/sample/SampleDriverQuant.cpp +++ b/nn/driver/sample/SampleDriverQuant.cpp @@ -67,7 +67,7 @@ std::vector<bool> SampleDriverQuant::getSupportedOperationsImpl(const V1_3::Mode std::vector<bool> supported(count); for (size_t i = 0; i < count; i++) { const Operation& operation = model.main.operations[i]; - if (operation.inputs.size() > 0) { + if (!isExtensionOperationType(operation.type) && operation.inputs.size() > 0) { const Operand& firstOperand = model.main.operands[operation.inputs[0]]; supported[i] = isQuantized(firstOperand.type); if (operation.type == OperationType::SELECT) { diff --git a/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp b/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp index cdafa344f..faeeda3fd 100644 --- a/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp +++ b/nn/runtime/test/fibonacci_extension/FibonacciExtensionTest.cpp @@ -14,6 +14,12 @@ * limitations under the License. */ +#include <gtest/gtest.h> + +#include <vector> + +#include "FibonacciDriver.h" +#include "FibonacciExtension.h" #include "HalInterfaces.h" #include "Manager.h" #include "NeuralNetworks.h" @@ -24,13 +30,6 @@ #include "Utils.h" #include "ValidateHal.h" -#include <gtest/gtest.h> - -#include "FibonacciDriver.h" -#include "FibonacciExtension.h" - -#include <vector> - namespace android { namespace nn { namespace { @@ -58,27 +57,26 @@ class FibonacciExtensionTest : public ::testing::Test { uint32_t numDevices = 0; ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR); - ANeuralNetworksDevice* fibonacciDevice = nullptr; - ANeuralNetworksDevice* cpuDevice = nullptr; for (uint32_t i = 0; i < numDevices; i++) { ANeuralNetworksDevice* device = nullptr; EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR); + mAllDevices.push_back(device); bool supportsFibonacciExtension; ASSERT_EQ( ANeuralNetworksDevice_getExtensionSupport( device, EXAMPLE_FIBONACCI_EXTENSION_NAME, &supportsFibonacciExtension), ANEURALNETWORKS_NO_ERROR); if (supportsFibonacciExtension) { - ASSERT_EQ(fibonacciDevice, nullptr) << "Found multiple Fibonacci drivers"; - fibonacciDevice = device; + ASSERT_EQ(mFibonacciDevice, nullptr) << "Found multiple Fibonacci drivers"; + mFibonacciDevice = device; } else if (DeviceManager::get()->forTest_isCpuDevice(device)) { - ASSERT_EQ(cpuDevice, nullptr) << "Found multiple CPU drivers"; - cpuDevice = device; + ASSERT_EQ(mCpuDevice, nullptr) << "Found multiple CPU drivers"; + mCpuDevice = device; } } - ASSERT_NE(fibonacciDevice, nullptr) << "Expecting Fibonacci driver to be available"; - ASSERT_NE(cpuDevice, nullptr) << "Expecting CPU driver to be available"; - mDevices = {fibonacciDevice, cpuDevice}; + ASSERT_NE(mFibonacciDevice, nullptr) << "Expecting Fibonacci driver to be available"; + ASSERT_NE(mCpuDevice, nullptr) << "Expecting CPU driver to be available"; + mDevices = {mFibonacciDevice, mCpuDevice}; } virtual void TearDown() { @@ -92,12 +90,13 @@ class FibonacciExtensionTest : public ::testing::Test { TypeManager::get()->forTest_reset(); } - void checkSupportedOperations(const std::vector<bool>& expected) { + void checkSupportedOperations(const std::vector<bool>& expected, + const std::vector<ANeuralNetworksDevice*> devices) { const uint32_t kMaxNumberOperations = 256; EXPECT_LE(expected.size(), kMaxNumberOperations); bool supported[kMaxNumberOperations] = {false}; EXPECT_EQ(ANeuralNetworksModel_getSupportedOperationsForDevices( - mModel.getHandle(), mDevices.data(), mDevices.size(), supported), + mModel.getHandle(), devices.data(), devices.size(), supported), ANEURALNETWORKS_NO_ERROR); for (size_t i = 0; i < expected.size(); ++i) { SCOPED_TRACE(::testing::Message() << "i = " << i); @@ -105,6 +104,10 @@ class FibonacciExtensionTest : public ::testing::Test { } } + void checkSupportedOperations(const std::vector<bool>& expected) { + checkSupportedOperations(expected, mDevices); + } + void prepareForExecution() { ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), mDevices.data(), mDevices.size(), &mCompilation), @@ -114,7 +117,10 @@ class FibonacciExtensionTest : public ::testing::Test { ANEURALNETWORKS_NO_ERROR); } - std::vector<ANeuralNetworksDevice*> mDevices; + ANeuralNetworksDevice* mFibonacciDevice = nullptr; + ANeuralNetworksDevice* mCpuDevice = nullptr; + std::vector<ANeuralNetworksDevice*> mDevices; // Fibonacci and CPU devices. + std::vector<ANeuralNetworksDevice*> mAllDevices; ANeuralNetworksExecution* mExecution = nullptr; ANeuralNetworksCompilation* mCompilation = nullptr; ExtensionModel mModel; @@ -334,6 +340,20 @@ TEST_F(FibonacciExtensionTest, InvalidOperation) { ASSERT_EQ(ANeuralNetworksCompilation_finish(mCompilation), ANEURALNETWORKS_BAD_DATA); } +TEST_F(FibonacciExtensionTest, GetSupportedOperations) { + ExtensionOperandType inputType(Type::TENSOR_FLOAT32, {1}); + ExtensionOperandType outputType(Type::TENSOR_FLOAT32, {1}); + createModel(&mModel, inputType, outputType, /*addNopOperations=*/false); + + for (ANeuralNetworksDevice* device : mAllDevices) { + const char* name = nullptr; + ASSERT_EQ(ANeuralNetworksDevice_getName(device, &name), ANEURALNETWORKS_NO_ERROR); + SCOPED_TRACE(::testing::Message() << "device = " << name); + // Only Fibonacci device should support Fibonacci operation. + checkSupportedOperations({device == mFibonacciDevice}, {device}); + } +} + } // namespace } // namespace nn } // namespace android |