diff options
author | Xusong Wang <xusongw@google.com> | 2019-05-08 18:40:42 -0700 |
---|---|---|
committer | Xusong Wang <xusongw@google.com> | 2019-05-14 16:10:24 -0700 |
commit | 5c7e6f98f159a34e4e4d886339a40699060e41eb (patch) | |
tree | 327b66e9a68c820fd323dcf37b37aead7ab91de8 | |
parent | 0538dfbb3dbccab158775f62e251f0d962945455 (diff) | |
download | ml-5c7e6f98f159a34e4e4d886339a40699060e41eb.tar.gz |
Fix the mismatch between compliantWith and validateModel.
* Check rank 0 operand in compliantWith
* Check hardware buffer in compliantWith
* Disallow hardware buffer for pre-1.2 model in validateModel
* Add compliance tests for rank 0 tensor and hardware buffer
Bug: 131845106
Test: NeuralNetworksTest_static
Test: NeuralNetworksTest_static_fuzzing
Test: Above tests with debug.nn.strict-slicing set to 1
Change-Id: I0e2f80f93074d15ea68ac5fd162ca9e70e128835
-rw-r--r-- | nn/common/Utils.cpp | 34 | ||||
-rw-r--r-- | nn/common/ValidateHal.cpp | 93 | ||||
-rw-r--r-- | nn/common/include/Utils.h | 2 | ||||
-rw-r--r-- | nn/common/include/ValidateHal.h | 13 | ||||
-rw-r--r-- | nn/runtime/test/TestCompliance.cpp | 88 |
5 files changed, 178 insertions, 52 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index a1b3b6a7e..731127a26 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -2076,6 +2076,12 @@ static hidl_vec<V1_1::Operation> convertToV1_1(const hidl_vec<V1_0::Operation>& return result; } +bool compliantWithV1_0(const V1_2::Operand& operand) { + return validOperandType(static_cast<V1_0::OperandType>(operand.type)) && + (nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type)) || + operand.dimensions.size() != 0); +} + V1_0::Model convertToV1_0(const V1_0::Model& model) { return model; } @@ -2119,7 +2125,33 @@ void logModelToInfo(const V1_2::Model& model) { static bool compliantWith(HalVersion version, const V1_2::Model& model, std::set<uint32_t>* noncompliantOperations) { - auto localValidateOperation = [&model, version](const V1_2::Operation& op) { + if (version >= HalVersion::V1_2) return true; + + // A boolean vector indicating whether each pool is compliant with the target HAL version. + std::vector<bool> isPoolCompliant(model.pools.size(), false); + std::transform(model.pools.begin(), model.pools.end(), isPoolCompliant.begin(), + [version](const hidl_memory& pool) { return validatePool(pool, version); }); + + // A boolean vector indicating whether each operand is compliant with the target HAL version. + std::vector<bool> isOperandCompliant(model.operands.size(), false); + std::transform(model.operands.begin(), model.operands.end(), isOperandCompliant.begin(), + [&isPoolCompliant](const V1_2::Operand& op) { + // There is no V1_1::Operand -- both V1_0::Model and V1_1::Model use + // V1_0::Operand. + return compliantWithV1_0(op) && + !(op.lifetime == OperandLifeTime::CONSTANT_REFERENCE && + !isPoolCompliant[op.location.poolIndex]); + }); + + auto allOperandsCompliant = [&isOperandCompliant](const hidl_vec<uint32_t>& indices) { + return std::all_of( + indices.begin(), indices.end(), + [&isOperandCompliant](const uint32_t ind) { return isOperandCompliant[ind]; }); + }; + + auto localValidateOperation = [&model, version, + &allOperandsCompliant](const V1_2::Operation& op) { + if (!allOperandsCompliant(op.inputs) || !allOperandsCompliant(op.outputs)) return false; int error = validateOperation( static_cast<int32_t>(op.type), op.inputs.size(), op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(), diff --git a/nn/common/ValidateHal.cpp b/nn/common/ValidateHal.cpp index 1015d922d..421730a1d 100644 --- a/nn/common/ValidateHal.cpp +++ b/nn/common/ValidateHal.cpp @@ -27,6 +27,21 @@ namespace android { namespace nn { +template <class T_Model> +struct ModelToHalVersion; +template <> +struct ModelToHalVersion<V1_0::Model> { + static constexpr HalVersion version = HalVersion::V1_0; +}; +template <> +struct ModelToHalVersion<V1_1::Model> { + static constexpr HalVersion version = HalVersion::V1_1; +}; +template <> +struct ModelToHalVersion<V1_2::Model> { + static constexpr HalVersion version = HalVersion::V1_2; +}; + class MemoryAccessVerifier { public: MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools) @@ -418,22 +433,26 @@ static bool validateOperations(const hidl_vec<VersionedOperation>& operations, return true; } -static bool validatePools(const hidl_vec<hidl_memory>& pools) { - for (const hidl_memory& memory : pools) { - const auto& name = memory.name(); - if (name != "ashmem" && name != "mmap_fd" && name != "hardware_buffer_blob" && - name != "hardware_buffer") { - LOG(ERROR) << "Unsupported memory type " << name; - return false; - } - if (memory.handle() == nullptr) { - LOG(ERROR) << "Memory of type " << name << " is null"; - return false; - } +bool validatePool(const hidl_memory& pool, HalVersion ver) { + const auto& name = pool.name(); + if (name != "ashmem" && name != "mmap_fd" && + ((ver < HalVersion::V1_2) || + (name != "hardware_buffer_blob" && name != "hardware_buffer"))) { + LOG(ERROR) << "Unsupported memory type " << name; + return false; + } + if (pool.handle() == nullptr) { + LOG(ERROR) << "Memory of type " << name << " is null"; + return false; } return true; } +static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) { + return std::all_of(pools.begin(), pools.end(), + [ver](const hidl_memory& pool) { return validatePool(pool, ver); }); +} + static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes, const hidl_vec<Operand>& operands, OperandLifeTime lifetime) { const size_t operandCount = operands.size(); @@ -460,10 +479,10 @@ static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes, return true; } -template <typename VersionedModel> -static bool validateModelVersioned(const VersionedModel& model, bool allowUnspecifiedRank) { - NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, - "validateModelVersioned"); +template <class T_Model> +bool validateModel(const T_Model& model) { + NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel"); + HalVersion version = ModelToHalVersion<T_Model>::version; if (model.operations.size() == 0 || model.operands.size() == 0) { LOG(ERROR) << "Invalid empty model."; return false; @@ -472,26 +491,18 @@ static bool validateModelVersioned(const VersionedModel& model, bool allowUnspec // validations we can use operands upcasted to the latest version. const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands); return (validateOperands(model.operands, model.operandValues, model.pools, - allowUnspecifiedRank) && + /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) && validateOperations(model.operations, latestVersionOperands) && validateModelInputOutputs(model.inputIndexes, latestVersionOperands, OperandLifeTime::MODEL_INPUT) && validateModelInputOutputs(model.outputIndexes, latestVersionOperands, OperandLifeTime::MODEL_OUTPUT) && - validatePools(model.pools)); -} - -bool validateModel(const V1_0::Model& model) { - return validateModelVersioned(model, /*allowUnspecifiedRank=*/false); + validatePools(model.pools, version)); } -bool validateModel(const V1_1::Model& model) { - return validateModelVersioned(model, /*allowUnspecifiedRank=*/false); -} - -bool validateModel(const V1_2::Model& model) { - return validateModelVersioned(model, /*allowUnspecifiedRank=*/true); -} +template bool validateModel<V1_0::Model>(const V1_0::Model& model); +template bool validateModel<V1_1::Model>(const V1_1::Model& model); +template bool validateModel<V1_2::Model>(const V1_2::Model& model); // Validates the arguments of a request. type is either "input" or "output" and is used // for printing error messages. The operandIndexes is the appropriate array of input @@ -572,29 +583,21 @@ static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArg return true; } -template <typename VersionedModel> -static bool validateRequestVersioned(const Request& request, const VersionedModel& model, - bool allowDynamicOutputShape) { +template <class T_Model> +bool validateRequest(const Request& request, const T_Model& model) { + HalVersion version = ModelToHalVersion<T_Model>::version; return (validateRequestArguments(request.inputs, model.inputIndexes, convertToV1_2(model.operands), request.pools, /*allowUnspecified=*/false, "input") && validateRequestArguments(request.outputs, model.outputIndexes, convertToV1_2(model.operands), request.pools, - /*allowUnspecified=*/allowDynamicOutputShape, "output") && - validatePools(request.pools)); + /*allowUnspecified=*/version >= HalVersion::V1_2, "output") && + validatePools(request.pools, version)); } -bool validateRequest(const Request& request, const V1_0::Model& model) { - return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false); -} - -bool validateRequest(const Request& request, const V1_1::Model& model) { - return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false); -} - -bool validateRequest(const Request& request, const V1_2::Model& model) { - return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/true); -} +template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model); +template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model); +template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model); bool validateExecutionPreference(ExecutionPreference preference) { return preference == ExecutionPreference::LOW_POWER || diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h index 64472f12a..bf6cffda4 100644 --- a/nn/common/include/Utils.h +++ b/nn/common/include/Utils.h @@ -315,6 +315,8 @@ bool compliantWithV1_2(const V1_0::Capabilities& capabilities); bool compliantWithV1_2(const V1_1::Capabilities& capabilities); bool compliantWithV1_2(const V1_2::Capabilities& capabilities); +bool compliantWithV1_0(const V1_2::Operand& operand); + // If noncompliantOperations != nullptr, then // precondition: noncompliantOperations->empty() // postcondition: *noncompliantOperations consists of the indices of the noncompliant diff --git a/nn/common/include/ValidateHal.h b/nn/common/include/ValidateHal.h index c953d8a10..4275a24a7 100644 --- a/nn/common/include/ValidateHal.h +++ b/nn/common/include/ValidateHal.h @@ -36,17 +36,15 @@ enum class HalVersion : int32_t { // IMPORTANT: This function cannot validate that OEM operation and operands // are correctly defined, as these are specific to each implementation. // Each driver should do their own validation of OEM types. -bool validateModel(const V1_0::Model& model); -bool validateModel(const V1_1::Model& model); -bool validateModel(const V1_2::Model& model); +template <class T_Model> +bool validateModel(const T_Model& model); // Verfies that the request for the given model is valid. // IMPORTANT: This function cannot validate that OEM operation and operands // are correctly defined, as these are specific to each implementation. // Each driver should do their own validation of OEM types. -bool validateRequest(const Request& request, const V1_0::Model& model); -bool validateRequest(const Request& request, const V1_1::Model& model); -bool validateRequest(const Request& request, const V1_2::Model& model); +template <class T_Model> +bool validateRequest(const Request& request, const T_Model& model); // Verfies that the execution preference is valid. bool validateExecutionPreference(ExecutionPreference preference); @@ -58,6 +56,9 @@ bool validOperationType(V1_2::OperationType operation); bool validOperandType(V1_0::OperandType operand); bool validOperandType(V1_2::OperandType operand); +// Verfies that the memory pool is valid in the specified HAL version. +bool validatePool(const hidl_memory& pool, HalVersion ver = HalVersion::LATEST); + } // namespace nn } // namespace android diff --git a/nn/runtime/test/TestCompliance.cpp b/nn/runtime/test/TestCompliance.cpp index 93918c803..52764154c 100644 --- a/nn/runtime/test/TestCompliance.cpp +++ b/nn/runtime/test/TestCompliance.cpp @@ -27,12 +27,15 @@ namespace compliance_test { using namespace ::android::nn; using HidlModel = V1_2::Model; using WrapperModel = test_wrapper::Model; +using WrapperOperandType = test_wrapper::OperandType; +using WrapperType = test_wrapper::Type; // Creates a HIDL model from a creator of the wrapper model. static HidlModel createHidlModel(std::function<void(WrapperModel*)> createModel) { HidlModel hidlModel; WrapperModel wrapperModel; createModel(&wrapperModel); + EXPECT_EQ(wrapperModel.finish(), test_wrapper::Result::NO_ERROR); ModelBuilder* modelBuilder = reinterpret_cast<ModelBuilder*>(wrapperModel.getHandle()); modelBuilder->setHidlModel(&hidlModel); return hidlModel; @@ -56,4 +59,89 @@ void ComplianceTest::testAvailableSinceV1_0(std::function<void(WrapperModel*)> c ASSERT_TRUE(compliantWithV1_0(model)); } +static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1}); +static const WrapperOperandType kTypeTensorFloatRank0(WrapperType::TENSOR_FLOAT32, {}); +static const WrapperOperandType kTypeInt32(WrapperType::INT32, {}); + +TEST_F(ComplianceTest, Rank0TensorModelInput) { + int32_t act_init = 0; + // A simple ADD operation: op1 ADD op2 = op3, with op1 and op2 of rank 0. + testAvailableSinceV1_2([&act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloatRank0); + auto op2 = model->addOperand(&kTypeTensorFloatRank0); + auto act = model->addOperand(&kTypeInt32); + auto op3 = model->addOperand(&kTypeTensorFloat); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->identifyInputsAndOutputs({op1, op2}, {op3}); + assert(model->isValid()); + }); +} + +TEST_F(ComplianceTest, Rank0TensorModelOutput) { + int32_t act_init = 0; + // A simple ADD operation: op1 ADD op2 = op3, with op3 of rank 0. + testAvailableSinceV1_2([&act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloat); + auto op2 = model->addOperand(&kTypeTensorFloat); + auto act = model->addOperand(&kTypeInt32); + auto op3 = model->addOperand(&kTypeTensorFloatRank0); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->identifyInputsAndOutputs({op1, op2}, {op3}); + assert(model->isValid()); + }); +} + +TEST_F(ComplianceTest, Rank0TensorTemporaryVariable) { + int32_t act_init = 0; + // Two ADD operations: op1 ADD op2 = op3, op3 ADD op4 = op5, with op3 of rank 0. + testAvailableSinceV1_2([&act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloat); + auto op2 = model->addOperand(&kTypeTensorFloat); + auto op3 = model->addOperand(&kTypeTensorFloatRank0); + auto op4 = model->addOperand(&kTypeTensorFloat); + auto op5 = model->addOperand(&kTypeTensorFloat); + auto act = model->addOperand(&kTypeInt32); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->addOperation(ANEURALNETWORKS_ADD, {op3, op4, act}, {op5}); + model->identifyInputsAndOutputs({op1, op2, op4}, {op5}); + assert(model->isValid()); + }); +} + +TEST_F(ComplianceTest, HardwareBuffer) { + const size_t memorySize = 20; + AHardwareBuffer_Desc desc{ + .width = memorySize, + .height = 1, + .layers = 1, + .format = AHARDWAREBUFFER_FORMAT_BLOB, + .usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, + }; + + AHardwareBuffer* buffer = nullptr; + ASSERT_EQ(AHardwareBuffer_allocate(&desc, &buffer), 0); + test_wrapper::Memory memory(buffer); + ASSERT_TRUE(memory.isValid()); + + int32_t act_init = 0; + + // A simple ADD operation: op1 ADD op2 = op3, with op2 using a const hardware buffer. + testAvailableSinceV1_2([&memory, &act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloat); + auto op2 = model->addOperand(&kTypeTensorFloat); + auto act = model->addOperand(&kTypeInt32); + auto op3 = model->addOperand(&kTypeTensorFloat); + model->setOperandValueFromMemory(op2, &memory, 0, sizeof(float)); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->identifyInputsAndOutputs({op1}, {op3}); + assert(model->isValid()); + }); + + AHardwareBuffer_release(buffer); +} + } // namespace compliance_test |