diff options
Diffstat (limited to 'source/val/validate_function.cpp')
-rw-r--r-- | source/val/validate_function.cpp | 105 |
1 files changed, 51 insertions, 54 deletions
diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp index db402aa3..0ccf5a9e 100644 --- a/source/val/validate_function.cpp +++ b/source/val/validate_function.cpp @@ -28,8 +28,7 @@ namespace { // of the decorations that apply to |a|. bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, ValidationState_t& _) { - if (a->opcode() != spv::Op::OpTypePointer || - b->opcode() != spv::Op::OpTypePointer) { + if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) { return false; } @@ -57,7 +56,7 @@ bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { const auto function_type_id = inst->GetOperandAs<uint32_t>(3); const auto function_type = _.FindDef(function_type_id); - if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) { + if (!function_type || SpvOpTypeFunction != function_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunction Function Type <id> " << _.getIdName(function_type_id) << " is not a function type."; @@ -71,21 +70,21 @@ spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { << _.getIdName(return_id) << "."; } - const std::vector<spv::Op> acceptable = { - spv::Op::OpGroupDecorate, - spv::Op::OpDecorate, - spv::Op::OpEnqueueKernel, - spv::Op::OpEntryPoint, - spv::Op::OpExecutionMode, - spv::Op::OpExecutionModeId, - spv::Op::OpFunctionCall, - spv::Op::OpGetKernelNDrangeSubGroupCount, - spv::Op::OpGetKernelNDrangeMaxSubGroupSize, - spv::Op::OpGetKernelWorkGroupSize, - spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple, - spv::Op::OpGetKernelLocalSizeForSubgroupCount, - spv::Op::OpGetKernelMaxNumSubgroups, - spv::Op::OpName}; + const std::vector<SpvOp> acceptable = { + SpvOpGroupDecorate, + SpvOpDecorate, + SpvOpEnqueueKernel, + SpvOpEntryPoint, + SpvOpExecutionMode, + SpvOpExecutionModeId, + SpvOpFunctionCall, + SpvOpGetKernelNDrangeSubGroupCount, + SpvOpGetKernelNDrangeMaxSubGroupSize, + SpvOpGetKernelWorkGroupSize, + SpvOpGetKernelPreferredWorkGroupSizeMultiple, + SpvOpGetKernelLocalSizeForSubgroupCount, + SpvOpGetKernelMaxNumSubgroups, + SpvOpName}; for (auto& pair : inst->uses()) { const auto* use = pair.first; if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == @@ -113,14 +112,14 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _, auto func_inst = &_.ordered_instructions()[inst_num]; while (--inst_num) { func_inst = &_.ordered_instructions()[inst_num]; - if (func_inst->opcode() == spv::Op::OpFunction) { + if (func_inst->opcode() == SpvOpFunction) { break; - } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) { + } else if (func_inst->opcode() == SpvOpFunctionParameter) { ++param_index; } } - if (func_inst->opcode() != spv::Op::OpFunction) { + if (func_inst->opcode() != SpvOpFunction) { return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Function parameter must be preceded by a function."; } @@ -151,25 +150,25 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _, // Validate that PhysicalStorageBuffer have one of Restrict, Aliased, // RestrictPointer, or AliasedPointer. auto param_nonarray_type_id = param_type->id(); - while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) { + while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) { param_nonarray_type_id = _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u); } - if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer) { + if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) { auto param_nonarray_type = _.FindDef(param_nonarray_type_id); - if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) == - spv::StorageClass::PhysicalStorageBuffer) { + if (param_nonarray_type->GetOperandAs<uint32_t>(1u) == + SpvStorageClassPhysicalStorageBuffer) { // check for Aliased or Restrict const auto& decorations = _.id_decorations(inst->id()); bool foundAliased = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { - return spv::Decoration::Aliased == d.dec_type(); + return SpvDecorationAliased == d.dec_type(); }); bool foundRestrict = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { - return spv::Decoration::Restrict == d.dec_type(); + return SpvDecorationRestrict == d.dec_type(); }); if (!foundAliased && !foundRestrict) { @@ -188,20 +187,20 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _, const auto pointee_type_id = param_nonarray_type->GetOperandAs<uint32_t>(2); const auto pointee_type = _.FindDef(pointee_type_id); - if (spv::Op::OpTypePointer == pointee_type->opcode() && - pointee_type->GetOperandAs<spv::StorageClass>(1u) == - spv::StorageClass::PhysicalStorageBuffer) { + if (SpvOpTypePointer == pointee_type->opcode() && + pointee_type->GetOperandAs<uint32_t>(1u) == + SpvStorageClassPhysicalStorageBuffer) { // check for AliasedPointer/RestrictPointer const auto& decorations = _.id_decorations(inst->id()); bool foundAliased = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { - return spv::Decoration::AliasedPointer == d.dec_type(); + return SpvDecorationAliasedPointer == d.dec_type(); }); bool foundRestrict = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { - return spv::Decoration::RestrictPointer == d.dec_type(); + return SpvDecorationRestrictPointer == d.dec_type(); }); if (!foundAliased && !foundRestrict) { @@ -227,7 +226,7 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, const Instruction* inst) { const auto function_id = inst->GetOperandAs<uint32_t>(2); const auto function = _.FindDef(function_id); - if (!function || spv::Op::OpFunction != function->opcode()) { + if (!function || SpvOpFunction != function->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionCall Function <id> " << _.getIdName(function_id) << " is not a function."; @@ -243,7 +242,7 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, const auto function_type_id = function->GetOperandAs<uint32_t>(3); const auto function_type = _.FindDef(function_type_id); - if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) { + if (!function_type || function_type->opcode() != SpvOpTypeFunction) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Missing function type definition."; } @@ -286,21 +285,20 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, } } - if (_.addressing_model() == spv::AddressingModel::Logical) { - if (parameter_type->opcode() == spv::Op::OpTypePointer && + if (_.addressing_model() == SpvAddressingModelLogical) { + if (parameter_type->opcode() == SpvOpTypePointer && !_.options()->relax_logical_pointer) { - spv::StorageClass sc = - parameter_type->GetOperandAs<spv::StorageClass>(1u); + SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u); // Validate which storage classes can be pointer operands. switch (sc) { - case spv::StorageClass::UniformConstant: - case spv::StorageClass::Function: - case spv::StorageClass::Private: - case spv::StorageClass::Workgroup: - case spv::StorageClass::AtomicCounter: + case SpvStorageClassUniformConstant: + case SpvStorageClassFunction: + case SpvStorageClassPrivate: + case SpvStorageClassWorkgroup: + case SpvStorageClassAtomicCounter: // These are always allowed. break; - case spv::StorageClass::StorageBuffer: + case SpvStorageClassStorageBuffer: if (!_.features().variable_pointers) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "StorageBuffer pointer operand " @@ -315,14 +313,13 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, } // Validate memory object declaration requirements. - if (argument->opcode() != spv::Op::OpVariable && - argument->opcode() != spv::Op::OpFunctionParameter) { + if (argument->opcode() != SpvOpVariable && + argument->opcode() != SpvOpFunctionParameter) { const bool ssbo_vptr = _.features().variable_pointers && - sc == spv::StorageClass::StorageBuffer; - const bool wg_vptr = - _.HasCapability(spv::Capability::VariablePointers) && - sc == spv::StorageClass::Workgroup; - const bool uc_ptr = sc == spv::StorageClass::UniformConstant; + sc == SpvStorageClassStorageBuffer; + const bool wg_vptr = _.HasCapability(SpvCapabilityVariablePointers) && + sc == SpvStorageClassWorkgroup; + const bool uc_ptr = sc == SpvStorageClassUniformConstant; if (!ssbo_vptr && !wg_vptr && !uc_ptr) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Pointer operand " << _.getIdName(argument_id) @@ -339,13 +336,13 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { switch (inst->opcode()) { - case spv::Op::OpFunction: + case SpvOpFunction: if (auto error = ValidateFunction(_, inst)) return error; break; - case spv::Op::OpFunctionParameter: + case SpvOpFunctionParameter: if (auto error = ValidateFunctionParameter(_, inst)) return error; break; - case spv::Op::OpFunctionCall: + case SpvOpFunctionCall: if (auto error = ValidateFunctionCall(_, inst)) return error; break; default: |