aboutsummaryrefslogtreecommitdiff
path: root/source/val/validate_function.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/val/validate_function.cpp')
-rw-r--r--source/val/validate_function.cpp105
1 files changed, 51 insertions, 54 deletions
diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp
index db402aa3..0ccf5a9e 100644
--- a/source/val/validate_function.cpp
+++ b/source/val/validate_function.cpp
@@ -28,8 +28,7 @@ namespace {
// of the decorations that apply to |a|.
bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
ValidationState_t& _) {
- if (a->opcode() != spv::Op::OpTypePointer ||
- b->opcode() != spv::Op::OpTypePointer) {
+ if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) {
return false;
}
@@ -57,7 +56,7 @@ bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
const auto function_type = _.FindDef(function_type_id);
- if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
+ if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunction Function Type <id> " << _.getIdName(function_type_id)
<< " is not a function type.";
@@ -71,21 +70,21 @@ spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
<< _.getIdName(return_id) << ".";
}
- const std::vector<spv::Op> acceptable = {
- spv::Op::OpGroupDecorate,
- spv::Op::OpDecorate,
- spv::Op::OpEnqueueKernel,
- spv::Op::OpEntryPoint,
- spv::Op::OpExecutionMode,
- spv::Op::OpExecutionModeId,
- spv::Op::OpFunctionCall,
- spv::Op::OpGetKernelNDrangeSubGroupCount,
- spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
- spv::Op::OpGetKernelWorkGroupSize,
- spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
- spv::Op::OpGetKernelLocalSizeForSubgroupCount,
- spv::Op::OpGetKernelMaxNumSubgroups,
- spv::Op::OpName};
+ const std::vector<SpvOp> acceptable = {
+ SpvOpGroupDecorate,
+ SpvOpDecorate,
+ SpvOpEnqueueKernel,
+ SpvOpEntryPoint,
+ SpvOpExecutionMode,
+ SpvOpExecutionModeId,
+ SpvOpFunctionCall,
+ SpvOpGetKernelNDrangeSubGroupCount,
+ SpvOpGetKernelNDrangeMaxSubGroupSize,
+ SpvOpGetKernelWorkGroupSize,
+ SpvOpGetKernelPreferredWorkGroupSizeMultiple,
+ SpvOpGetKernelLocalSizeForSubgroupCount,
+ SpvOpGetKernelMaxNumSubgroups,
+ SpvOpName};
for (auto& pair : inst->uses()) {
const auto* use = pair.first;
if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
@@ -113,14 +112,14 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _,
auto func_inst = &_.ordered_instructions()[inst_num];
while (--inst_num) {
func_inst = &_.ordered_instructions()[inst_num];
- if (func_inst->opcode() == spv::Op::OpFunction) {
+ if (func_inst->opcode() == SpvOpFunction) {
break;
- } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
+ } else if (func_inst->opcode() == SpvOpFunctionParameter) {
++param_index;
}
}
- if (func_inst->opcode() != spv::Op::OpFunction) {
+ if (func_inst->opcode() != SpvOpFunction) {
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
<< "Function parameter must be preceded by a function.";
}
@@ -151,25 +150,25 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _,
// Validate that PhysicalStorageBuffer have one of Restrict, Aliased,
// RestrictPointer, or AliasedPointer.
auto param_nonarray_type_id = param_type->id();
- while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) {
+ while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
param_nonarray_type_id =
_.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
}
- if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer) {
+ if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
- if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) ==
- spv::StorageClass::PhysicalStorageBuffer) {
+ if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
+ SpvStorageClassPhysicalStorageBuffer) {
// check for Aliased or Restrict
const auto& decorations = _.id_decorations(inst->id());
bool foundAliased = std::any_of(
decorations.begin(), decorations.end(), [](const Decoration& d) {
- return spv::Decoration::Aliased == d.dec_type();
+ return SpvDecorationAliased == d.dec_type();
});
bool foundRestrict = std::any_of(
decorations.begin(), decorations.end(), [](const Decoration& d) {
- return spv::Decoration::Restrict == d.dec_type();
+ return SpvDecorationRestrict == d.dec_type();
});
if (!foundAliased && !foundRestrict) {
@@ -188,20 +187,20 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _,
const auto pointee_type_id =
param_nonarray_type->GetOperandAs<uint32_t>(2);
const auto pointee_type = _.FindDef(pointee_type_id);
- if (spv::Op::OpTypePointer == pointee_type->opcode() &&
- pointee_type->GetOperandAs<spv::StorageClass>(1u) ==
- spv::StorageClass::PhysicalStorageBuffer) {
+ if (SpvOpTypePointer == pointee_type->opcode() &&
+ pointee_type->GetOperandAs<uint32_t>(1u) ==
+ SpvStorageClassPhysicalStorageBuffer) {
// check for AliasedPointer/RestrictPointer
const auto& decorations = _.id_decorations(inst->id());
bool foundAliased = std::any_of(
decorations.begin(), decorations.end(), [](const Decoration& d) {
- return spv::Decoration::AliasedPointer == d.dec_type();
+ return SpvDecorationAliasedPointer == d.dec_type();
});
bool foundRestrict = std::any_of(
decorations.begin(), decorations.end(), [](const Decoration& d) {
- return spv::Decoration::RestrictPointer == d.dec_type();
+ return SpvDecorationRestrictPointer == d.dec_type();
});
if (!foundAliased && !foundRestrict) {
@@ -227,7 +226,7 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _,
const Instruction* inst) {
const auto function_id = inst->GetOperandAs<uint32_t>(2);
const auto function = _.FindDef(function_id);
- if (!function || spv::Op::OpFunction != function->opcode()) {
+ if (!function || SpvOpFunction != function->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunctionCall Function <id> " << _.getIdName(function_id)
<< " is not a function.";
@@ -243,7 +242,7 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _,
const auto function_type_id = function->GetOperandAs<uint32_t>(3);
const auto function_type = _.FindDef(function_type_id);
- if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
+ if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Missing function type definition.";
}
@@ -286,21 +285,20 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _,
}
}
- if (_.addressing_model() == spv::AddressingModel::Logical) {
- if (parameter_type->opcode() == spv::Op::OpTypePointer &&
+ if (_.addressing_model() == SpvAddressingModelLogical) {
+ if (parameter_type->opcode() == SpvOpTypePointer &&
!_.options()->relax_logical_pointer) {
- spv::StorageClass sc =
- parameter_type->GetOperandAs<spv::StorageClass>(1u);
+ SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
// Validate which storage classes can be pointer operands.
switch (sc) {
- case spv::StorageClass::UniformConstant:
- case spv::StorageClass::Function:
- case spv::StorageClass::Private:
- case spv::StorageClass::Workgroup:
- case spv::StorageClass::AtomicCounter:
+ case SpvStorageClassUniformConstant:
+ case SpvStorageClassFunction:
+ case SpvStorageClassPrivate:
+ case SpvStorageClassWorkgroup:
+ case SpvStorageClassAtomicCounter:
// These are always allowed.
break;
- case spv::StorageClass::StorageBuffer:
+ case SpvStorageClassStorageBuffer:
if (!_.features().variable_pointers) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "StorageBuffer pointer operand "
@@ -315,14 +313,13 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _,
}
// Validate memory object declaration requirements.
- if (argument->opcode() != spv::Op::OpVariable &&
- argument->opcode() != spv::Op::OpFunctionParameter) {
+ if (argument->opcode() != SpvOpVariable &&
+ argument->opcode() != SpvOpFunctionParameter) {
const bool ssbo_vptr = _.features().variable_pointers &&
- sc == spv::StorageClass::StorageBuffer;
- const bool wg_vptr =
- _.HasCapability(spv::Capability::VariablePointers) &&
- sc == spv::StorageClass::Workgroup;
- const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
+ sc == SpvStorageClassStorageBuffer;
+ const bool wg_vptr = _.HasCapability(SpvCapabilityVariablePointers) &&
+ sc == SpvStorageClassWorkgroup;
+ const bool uc_ptr = sc == SpvStorageClassUniformConstant;
if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Pointer operand " << _.getIdName(argument_id)
@@ -339,13 +336,13 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _,
spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
- case spv::Op::OpFunction:
+ case SpvOpFunction:
if (auto error = ValidateFunction(_, inst)) return error;
break;
- case spv::Op::OpFunctionParameter:
+ case SpvOpFunctionParameter:
if (auto error = ValidateFunctionParameter(_, inst)) return error;
break;
- case spv::Op::OpFunctionCall:
+ case SpvOpFunctionCall:
if (auto error = ValidateFunctionCall(_, inst)) return error;
break;
default: