summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXusong Wang <xusongw@google.com>2019-05-08 18:40:42 -0700
committerXusong Wang <xusongw@google.com>2019-05-14 16:10:24 -0700
commit5c7e6f98f159a34e4e4d886339a40699060e41eb (patch)
tree327b66e9a68c820fd323dcf37b37aead7ab91de8
parent0538dfbb3dbccab158775f62e251f0d962945455 (diff)
downloadml-5c7e6f98f159a34e4e4d886339a40699060e41eb.tar.gz
Fix the mismatch between compliantWith and validateModel.
* Check rank 0 operand in compliantWith * Check hardware buffer in compliantWith * Disallow hardware buffer for pre-1.2 model in validateModel * Add compliance tests for rank 0 tensor and hardware buffer Bug: 131845106 Test: NeuralNetworksTest_static Test: NeuralNetworksTest_static_fuzzing Test: Above tests with debug.nn.strict-slicing set to 1 Change-Id: I0e2f80f93074d15ea68ac5fd162ca9e70e128835
-rw-r--r--nn/common/Utils.cpp34
-rw-r--r--nn/common/ValidateHal.cpp93
-rw-r--r--nn/common/include/Utils.h2
-rw-r--r--nn/common/include/ValidateHal.h13
-rw-r--r--nn/runtime/test/TestCompliance.cpp88
5 files changed, 178 insertions, 52 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index a1b3b6a7e..731127a26 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -2076,6 +2076,12 @@ static hidl_vec<V1_1::Operation> convertToV1_1(const hidl_vec<V1_0::Operation>&
return result;
}
+bool compliantWithV1_0(const V1_2::Operand& operand) {
+ return validOperandType(static_cast<V1_0::OperandType>(operand.type)) &&
+ (nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type)) ||
+ operand.dimensions.size() != 0);
+}
+
V1_0::Model convertToV1_0(const V1_0::Model& model) {
return model;
}
@@ -2119,7 +2125,33 @@ void logModelToInfo(const V1_2::Model& model) {
static bool compliantWith(HalVersion version, const V1_2::Model& model,
std::set<uint32_t>* noncompliantOperations) {
- auto localValidateOperation = [&model, version](const V1_2::Operation& op) {
+ if (version >= HalVersion::V1_2) return true;
+
+ // A boolean vector indicating whether each pool is compliant with the target HAL version.
+ std::vector<bool> isPoolCompliant(model.pools.size(), false);
+ std::transform(model.pools.begin(), model.pools.end(), isPoolCompliant.begin(),
+ [version](const hidl_memory& pool) { return validatePool(pool, version); });
+
+ // A boolean vector indicating whether each operand is compliant with the target HAL version.
+ std::vector<bool> isOperandCompliant(model.operands.size(), false);
+ std::transform(model.operands.begin(), model.operands.end(), isOperandCompliant.begin(),
+ [&isPoolCompliant](const V1_2::Operand& op) {
+ // There is no V1_1::Operand -- both V1_0::Model and V1_1::Model use
+ // V1_0::Operand.
+ return compliantWithV1_0(op) &&
+ !(op.lifetime == OperandLifeTime::CONSTANT_REFERENCE &&
+ !isPoolCompliant[op.location.poolIndex]);
+ });
+
+ auto allOperandsCompliant = [&isOperandCompliant](const hidl_vec<uint32_t>& indices) {
+ return std::all_of(
+ indices.begin(), indices.end(),
+ [&isOperandCompliant](const uint32_t ind) { return isOperandCompliant[ind]; });
+ };
+
+ auto localValidateOperation = [&model, version,
+ &allOperandsCompliant](const V1_2::Operation& op) {
+ if (!allOperandsCompliant(op.inputs) || !allOperandsCompliant(op.outputs)) return false;
int error = validateOperation(
static_cast<int32_t>(op.type), op.inputs.size(),
op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
diff --git a/nn/common/ValidateHal.cpp b/nn/common/ValidateHal.cpp
index 1015d922d..421730a1d 100644
--- a/nn/common/ValidateHal.cpp
+++ b/nn/common/ValidateHal.cpp
@@ -27,6 +27,21 @@
namespace android {
namespace nn {
+template <class T_Model>
+struct ModelToHalVersion;
+template <>
+struct ModelToHalVersion<V1_0::Model> {
+ static constexpr HalVersion version = HalVersion::V1_0;
+};
+template <>
+struct ModelToHalVersion<V1_1::Model> {
+ static constexpr HalVersion version = HalVersion::V1_1;
+};
+template <>
+struct ModelToHalVersion<V1_2::Model> {
+ static constexpr HalVersion version = HalVersion::V1_2;
+};
+
class MemoryAccessVerifier {
public:
MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
@@ -418,22 +433,26 @@ static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
return true;
}
-static bool validatePools(const hidl_vec<hidl_memory>& pools) {
- for (const hidl_memory& memory : pools) {
- const auto& name = memory.name();
- if (name != "ashmem" && name != "mmap_fd" && name != "hardware_buffer_blob" &&
- name != "hardware_buffer") {
- LOG(ERROR) << "Unsupported memory type " << name;
- return false;
- }
- if (memory.handle() == nullptr) {
- LOG(ERROR) << "Memory of type " << name << " is null";
- return false;
- }
+bool validatePool(const hidl_memory& pool, HalVersion ver) {
+ const auto& name = pool.name();
+ if (name != "ashmem" && name != "mmap_fd" &&
+ ((ver < HalVersion::V1_2) ||
+ (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
+ LOG(ERROR) << "Unsupported memory type " << name;
+ return false;
+ }
+ if (pool.handle() == nullptr) {
+ LOG(ERROR) << "Memory of type " << name << " is null";
+ return false;
}
return true;
}
+static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) {
+ return std::all_of(pools.begin(), pools.end(),
+ [ver](const hidl_memory& pool) { return validatePool(pool, ver); });
+}
+
static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
const size_t operandCount = operands.size();
@@ -460,10 +479,10 @@ static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
return true;
}
-template <typename VersionedModel>
-static bool validateModelVersioned(const VersionedModel& model, bool allowUnspecifiedRank) {
- NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED,
- "validateModelVersioned");
+template <class T_Model>
+bool validateModel(const T_Model& model) {
+ NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
+ HalVersion version = ModelToHalVersion<T_Model>::version;
if (model.operations.size() == 0 || model.operands.size() == 0) {
LOG(ERROR) << "Invalid empty model.";
return false;
@@ -472,26 +491,18 @@ static bool validateModelVersioned(const VersionedModel& model, bool allowUnspec
// validations we can use operands upcasted to the latest version.
const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands);
return (validateOperands(model.operands, model.operandValues, model.pools,
- allowUnspecifiedRank) &&
+ /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
validateOperations(model.operations, latestVersionOperands) &&
validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
OperandLifeTime::MODEL_INPUT) &&
validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
OperandLifeTime::MODEL_OUTPUT) &&
- validatePools(model.pools));
-}
-
-bool validateModel(const V1_0::Model& model) {
- return validateModelVersioned(model, /*allowUnspecifiedRank=*/false);
+ validatePools(model.pools, version));
}
-bool validateModel(const V1_1::Model& model) {
- return validateModelVersioned(model, /*allowUnspecifiedRank=*/false);
-}
-
-bool validateModel(const V1_2::Model& model) {
- return validateModelVersioned(model, /*allowUnspecifiedRank=*/true);
-}
+template bool validateModel<V1_0::Model>(const V1_0::Model& model);
+template bool validateModel<V1_1::Model>(const V1_1::Model& model);
+template bool validateModel<V1_2::Model>(const V1_2::Model& model);
// Validates the arguments of a request. type is either "input" or "output" and is used
// for printing error messages. The operandIndexes is the appropriate array of input
@@ -572,29 +583,21 @@ static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArg
return true;
}
-template <typename VersionedModel>
-static bool validateRequestVersioned(const Request& request, const VersionedModel& model,
- bool allowDynamicOutputShape) {
+template <class T_Model>
+bool validateRequest(const Request& request, const T_Model& model) {
+ HalVersion version = ModelToHalVersion<T_Model>::version;
return (validateRequestArguments(request.inputs, model.inputIndexes,
convertToV1_2(model.operands), request.pools,
/*allowUnspecified=*/false, "input") &&
validateRequestArguments(request.outputs, model.outputIndexes,
convertToV1_2(model.operands), request.pools,
- /*allowUnspecified=*/allowDynamicOutputShape, "output") &&
- validatePools(request.pools));
+ /*allowUnspecified=*/version >= HalVersion::V1_2, "output") &&
+ validatePools(request.pools, version));
}
-bool validateRequest(const Request& request, const V1_0::Model& model) {
- return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false);
-}
-
-bool validateRequest(const Request& request, const V1_1::Model& model) {
- return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false);
-}
-
-bool validateRequest(const Request& request, const V1_2::Model& model) {
- return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/true);
-}
+template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model);
+template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model);
+template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model);
bool validateExecutionPreference(ExecutionPreference preference) {
return preference == ExecutionPreference::LOW_POWER ||
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 64472f12a..bf6cffda4 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -315,6 +315,8 @@ bool compliantWithV1_2(const V1_0::Capabilities& capabilities);
bool compliantWithV1_2(const V1_1::Capabilities& capabilities);
bool compliantWithV1_2(const V1_2::Capabilities& capabilities);
+bool compliantWithV1_0(const V1_2::Operand& operand);
+
// If noncompliantOperations != nullptr, then
// precondition: noncompliantOperations->empty()
// postcondition: *noncompliantOperations consists of the indices of the noncompliant
diff --git a/nn/common/include/ValidateHal.h b/nn/common/include/ValidateHal.h
index c953d8a10..4275a24a7 100644
--- a/nn/common/include/ValidateHal.h
+++ b/nn/common/include/ValidateHal.h
@@ -36,17 +36,15 @@ enum class HalVersion : int32_t {
// IMPORTANT: This function cannot validate that OEM operation and operands
// are correctly defined, as these are specific to each implementation.
// Each driver should do their own validation of OEM types.
-bool validateModel(const V1_0::Model& model);
-bool validateModel(const V1_1::Model& model);
-bool validateModel(const V1_2::Model& model);
+template <class T_Model>
+bool validateModel(const T_Model& model);
// Verifies that the request for the given model is valid.
// IMPORTANT: This function cannot validate that OEM operation and operands
// are correctly defined, as these are specific to each implementation.
// Each driver should do their own validation of OEM types.
-bool validateRequest(const Request& request, const V1_0::Model& model);
-bool validateRequest(const Request& request, const V1_1::Model& model);
-bool validateRequest(const Request& request, const V1_2::Model& model);
+template <class T_Model>
+bool validateRequest(const Request& request, const T_Model& model);
// Verifies that the execution preference is valid.
bool validateExecutionPreference(ExecutionPreference preference);
@@ -58,6 +56,9 @@ bool validOperationType(V1_2::OperationType operation);
bool validOperandType(V1_0::OperandType operand);
bool validOperandType(V1_2::OperandType operand);
+// Verifies that the memory pool is valid in the specified HAL version.
+bool validatePool(const hidl_memory& pool, HalVersion ver = HalVersion::LATEST);
+
} // namespace nn
} // namespace android
diff --git a/nn/runtime/test/TestCompliance.cpp b/nn/runtime/test/TestCompliance.cpp
index 93918c803..52764154c 100644
--- a/nn/runtime/test/TestCompliance.cpp
+++ b/nn/runtime/test/TestCompliance.cpp
@@ -27,12 +27,15 @@ namespace compliance_test {
using namespace ::android::nn;
using HidlModel = V1_2::Model;
using WrapperModel = test_wrapper::Model;
+using WrapperOperandType = test_wrapper::OperandType;
+using WrapperType = test_wrapper::Type;
// Creates a HIDL model from a creator of the wrapper model.
static HidlModel createHidlModel(std::function<void(WrapperModel*)> createModel) {
HidlModel hidlModel;
WrapperModel wrapperModel;
createModel(&wrapperModel);
+ EXPECT_EQ(wrapperModel.finish(), test_wrapper::Result::NO_ERROR);
ModelBuilder* modelBuilder = reinterpret_cast<ModelBuilder*>(wrapperModel.getHandle());
modelBuilder->setHidlModel(&hidlModel);
return hidlModel;
@@ -56,4 +59,89 @@ void ComplianceTest::testAvailableSinceV1_0(std::function<void(WrapperModel*)> c
ASSERT_TRUE(compliantWithV1_0(model));
}
+static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1});
+static const WrapperOperandType kTypeTensorFloatRank0(WrapperType::TENSOR_FLOAT32, {});
+static const WrapperOperandType kTypeInt32(WrapperType::INT32, {});
+
+TEST_F(ComplianceTest, Rank0TensorModelInput) {
+ int32_t act_init = 0;
+ // A simple ADD operation: op1 ADD op2 = op3, with op1 and op2 of rank 0.
+ testAvailableSinceV1_2([&act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloatRank0);
+ auto op2 = model->addOperand(&kTypeTensorFloatRank0);
+ auto act = model->addOperand(&kTypeInt32);
+ auto op3 = model->addOperand(&kTypeTensorFloat);
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->identifyInputsAndOutputs({op1, op2}, {op3});
+ assert(model->isValid());
+ });
+}
+
+TEST_F(ComplianceTest, Rank0TensorModelOutput) {
+ int32_t act_init = 0;
+ // A simple ADD operation: op1 ADD op2 = op3, with op3 of rank 0.
+ testAvailableSinceV1_2([&act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloat);
+ auto op2 = model->addOperand(&kTypeTensorFloat);
+ auto act = model->addOperand(&kTypeInt32);
+ auto op3 = model->addOperand(&kTypeTensorFloatRank0);
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->identifyInputsAndOutputs({op1, op2}, {op3});
+ assert(model->isValid());
+ });
+}
+
+TEST_F(ComplianceTest, Rank0TensorTemporaryVariable) {
+ int32_t act_init = 0;
+ // Two ADD operations: op1 ADD op2 = op3, op3 ADD op4 = op5, with op3 of rank 0.
+ testAvailableSinceV1_2([&act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloat);
+ auto op2 = model->addOperand(&kTypeTensorFloat);
+ auto op3 = model->addOperand(&kTypeTensorFloatRank0);
+ auto op4 = model->addOperand(&kTypeTensorFloat);
+ auto op5 = model->addOperand(&kTypeTensorFloat);
+ auto act = model->addOperand(&kTypeInt32);
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->addOperation(ANEURALNETWORKS_ADD, {op3, op4, act}, {op5});
+ model->identifyInputsAndOutputs({op1, op2, op4}, {op5});
+ assert(model->isValid());
+ });
+}
+
+TEST_F(ComplianceTest, HardwareBuffer) {
+ const size_t memorySize = 20;
+ AHardwareBuffer_Desc desc{
+ .width = memorySize,
+ .height = 1,
+ .layers = 1,
+ .format = AHARDWAREBUFFER_FORMAT_BLOB,
+ .usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN,
+ };
+
+ AHardwareBuffer* buffer = nullptr;
+ ASSERT_EQ(AHardwareBuffer_allocate(&desc, &buffer), 0);
+ test_wrapper::Memory memory(buffer);
+ ASSERT_TRUE(memory.isValid());
+
+ int32_t act_init = 0;
+
+ // A simple ADD operation: op1 ADD op2 = op3, with op2 using a const hardware buffer.
+ testAvailableSinceV1_2([&memory, &act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloat);
+ auto op2 = model->addOperand(&kTypeTensorFloat);
+ auto act = model->addOperand(&kTypeInt32);
+ auto op3 = model->addOperand(&kTypeTensorFloat);
+ model->setOperandValueFromMemory(op2, &memory, 0, sizeof(float));
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->identifyInputsAndOutputs({op1}, {op3});
+ assert(model->isValid());
+ });
+
+ AHardwareBuffer_release(buffer);
+}
+
} // namespace compliance_test