aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRodrigo Locatti <rlocatti@nvidia.com>2024-04-10 11:40:10 -0300
committerGitHub <noreply@github.com>2024-04-10 10:40:10 -0400
commit6761288d39e2af51d73a5d8edb328dafc2054b1c (patch)
tree604aef28a8216effd08812d81249b735eea97cf3
parent3983d15a1d34fb95656818af0fc89c6260cbf316 (diff)
downloadspirv-tools-6761288d39e2af51d73a5d8edb328dafc2054b1c.tar.gz
Validator: Support SPV_NV_raw_access_chains (#5568)
-rw-r--r--source/opcode.cpp3
-rw-r--r--source/val/validate_annotation.cpp3
-rw-r--r--source/val/validate_decorations.cpp8
-rw-r--r--source/val/validate_memory.cpp123
-rw-r--r--test/val/CMakeLists.txt1
-rw-r--r--test/val/val_extension_spv_nv_raw_access_chains.cpp510
6 files changed, 644 insertions, 4 deletions
diff --git a/source/opcode.cpp b/source/opcode.cpp
index 38d1a1be..787dbb34 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -295,6 +295,7 @@ bool spvOpcodeReturnsLogicalVariablePointer(const spv::Op opcode) {
case spv::Op::OpPtrAccessChain:
case spv::Op::OpLoad:
case spv::Op::OpConstantNull:
+ case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
@@ -309,6 +310,7 @@ int32_t spvOpcodeReturnsLogicalPointer(const spv::Op opcode) {
case spv::Op::OpFunctionParameter:
case spv::Op::OpImageTexelPointer:
case spv::Op::OpCopyObject:
+ case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
@@ -754,6 +756,7 @@ bool spvOpcodeIsAccessChain(spv::Op opcode) {
case spv::Op::OpInBoundsAccessChain:
case spv::Op::OpPtrAccessChain:
case spv::Op::OpInBoundsPtrAccessChain:
+ case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
diff --git a/source/val/validate_annotation.cpp b/source/val/validate_annotation.cpp
index 106004d0..dac35857 100644
--- a/source/val/validate_annotation.cpp
+++ b/source/val/validate_annotation.cpp
@@ -161,7 +161,8 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec,
case spv::Decoration::RestrictPointer:
case spv::Decoration::AliasedPointer:
if (target->opcode() != spv::Op::OpVariable &&
- target->opcode() != spv::Op::OpFunctionParameter) {
+ target->opcode() != spv::Op::OpFunctionParameter &&
+ target->opcode() != spv::Op::OpRawAccessChainNV) {
return fail(0) << "must be a memory object declaration";
}
if (_.GetIdOpcode(target->type_id()) != spv::Op::OpTypePointer) {
diff --git a/source/val/validate_decorations.cpp b/source/val/validate_decorations.cpp
index caa4a6f1..bb1fea55 100644
--- a/source/val/validate_decorations.cpp
+++ b/source/val/validate_decorations.cpp
@@ -1556,7 +1556,8 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
const auto opcode = inst.opcode();
const auto type_id = inst.type_id();
if (opcode != spv::Op::OpVariable &&
- opcode != spv::Op::OpFunctionParameter) {
+ opcode != spv::Op::OpFunctionParameter &&
+ opcode != spv::Op::OpRawAccessChainNV) {
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
<< "Target of NonWritable decoration must be a memory object "
"declaration (a variable or a function parameter)";
@@ -1569,10 +1570,11 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
vstate.features().nonwritable_var_in_function_or_private) {
// New permitted feature in SPIR-V 1.4.
} else if (
- // It may point to a UBO, SSBO, or storage image.
+ // It may point to a UBO, SSBO, storage image, or raw access chain.
vstate.IsPointerToUniformBlock(type_id) ||
vstate.IsPointerToStorageBuffer(type_id) ||
- vstate.IsPointerToStorageImage(type_id)) {
+ vstate.IsPointerToStorageImage(type_id) ||
+ opcode == spv::Op::OpRawAccessChainNV) {
} else {
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
<< "Target of NonWritable decoration is invalid: must point to a "
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index c9ecf51a..2d6715f4 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -1427,6 +1427,126 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
return SPV_SUCCESS;
}
+spv_result_t ValidateRawAccessChain(ValidationState_t& _,
+ const Instruction* inst) {
+ std::string instr_name = "Op" + std::string(spvOpcodeString(inst->opcode()));
+
+ // The result type must be OpTypePointer.
+ const auto result_type = _.FindDef(inst->type_id());
+ if (spv::Op::OpTypePointer != result_type->opcode()) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "The Result Type of " << instr_name << " <id> "
+ << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
+ << spvOpcodeString(result_type->opcode()) << '.';
+ }
+
+ // The pointed storage class must be valid.
+ const auto storage_class = result_type->GetOperandAs<spv::StorageClass>(1);
+ if (storage_class != spv::StorageClass::StorageBuffer &&
+ storage_class != spv::StorageClass::PhysicalStorageBuffer &&
+ storage_class != spv::StorageClass::Uniform) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "The Result Type of " << instr_name << " <id> "
+ << _.getIdName(inst->id())
+ << " must point to a storage class of "
+ "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
+ }
+
+ // The pointed type must not be one in the list below.
+ const auto result_type_pointee =
+ _.FindDef(result_type->GetOperandAs<uint32_t>(2));
+ if (result_type_pointee->opcode() == spv::Op::OpTypeArray ||
+ result_type_pointee->opcode() == spv::Op::OpTypeMatrix ||
+ result_type_pointee->opcode() == spv::Op::OpTypeStruct) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "The Result Type of " << instr_name << " <id> "
+ << _.getIdName(inst->id())
+ << " must not point to "
+ "OpTypeArray, OpTypeMatrix, or OpTypeStruct.";
+ }
+
+ // Validate Stride is a OpConstant.
+ const auto stride = _.FindDef(inst->GetOperandAs<uint32_t>(3));
+ if (stride->opcode() != spv::Op::OpConstant) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "The Stride of " << instr_name << " <id> "
+ << _.getIdName(inst->id()) << " must be OpConstant. Found Op"
+ << spvOpcodeString(stride->opcode()) << '.';
+ }
+ // Stride type must be OpTypeInt
+ const auto stride_type = _.FindDef(stride->type_id());
+ if (stride_type->opcode() != spv::Op::OpTypeInt) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "The type of Stride of " << instr_name << " <id> "
+ << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
+ << spvOpcodeString(stride_type->opcode()) << '.';
+ }
+
+ // Index and Offset type must be OpTypeInt with a width of 32
+ const auto ValidateType = [&](const char* name,
+ int operandIndex) -> spv_result_t {
+ const auto value = _.FindDef(inst->GetOperandAs<uint32_t>(operandIndex));
+ const auto value_type = _.FindDef(value->type_id());
+ if (value_type->opcode() != spv::Op::OpTypeInt) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "The type of " << name << " of " << instr_name << " <id> "
+ << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
+ << spvOpcodeString(value_type->opcode()) << '.';
+ }
+ const auto width = value_type->GetOperandAs<uint32_t>(1);
+ if (width != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "The integer width of " << name << " of " << instr_name
+ << " <id> " << _.getIdName(inst->id()) << " must be 32. Found "
+ << width << '.';
+ }
+ return SPV_SUCCESS;
+ };
+ spv_result_t result;
+ result = ValidateType("Index", 4);
+ if (result != SPV_SUCCESS) {
+ return result;
+ }
+ result = ValidateType("Offset", 5);
+ if (result != SPV_SUCCESS) {
+ return result;
+ }
+
+ uint32_t access_operands = 0;
+ if (inst->operands().size() >= 7) {
+ access_operands = inst->GetOperandAs<uint32_t>(6);
+ }
+ if (access_operands &
+ uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
+ uint64_t stride_value = 0;
+ if (_.EvalConstantValUint64(stride->id(), &stride_value) &&
+ stride_value == 0) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Stride must not be zero when per-element robustness is used.";
+ }
+ }
+ if (access_operands &
+ uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) ||
+ access_operands &
+ uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
+ if (storage_class == spv::StorageClass::PhysicalStorageBuffer) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Storage class cannot be PhysicalStorageBuffer when "
+ "raw access chain robustness is used.";
+ }
+ }
+ if (access_operands &
+ uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) &&
+ access_operands &
+ uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Per-component robustness and per-element robustness are "
+ "mutually exclusive.";
+ }
+
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
const Instruction* inst) {
if (_.addressing_model() == spv::AddressingModel::Logical) {
@@ -1866,6 +1986,9 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpInBoundsPtrAccessChain:
if (auto error = ValidateAccessChain(_, inst)) return error;
break;
+ case spv::Op::OpRawAccessChainNV:
+ if (auto error = ValidateRawAccessChain(_, inst)) return error;
+ break;
case spv::Op::OpArrayLength:
if (auto error = ValidateArrayLength(_, inst)) return error;
break;
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 62d93bdd..9d6f6ea6 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -46,6 +46,7 @@ add_spvtools_unittest(TARGET val_abcde
val_extension_spv_khr_bit_instructions_test.cpp
val_extension_spv_khr_terminate_invocation_test.cpp
val_extension_spv_khr_subgroup_rotate_test.cpp
+ val_extension_spv_nv_raw_access_chains.cpp
val_ext_inst_test.cpp
val_ext_inst_debug_test.cpp
${VAL_TEST_COMMON_SRCS}
diff --git a/test/val/val_extension_spv_nv_raw_access_chains.cpp b/test/val/val_extension_spv_nv_raw_access_chains.cpp
new file mode 100644
index 00000000..f06d7cd4
--- /dev/null
+++ b/test/val/val_extension_spv_nv_raw_access_chains.cpp
@@ -0,0 +1,510 @@
+// Copyright (c) 2024 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "source/spirv_target_env.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+
+using ValidateSpvNVRawAccessChains = spvtest::ValidateBase<bool>;
+
+TEST_F(ValidateSpvNVRawAccessChains, Valid) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, NoCapability) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("requires one of these capabilities: RawAccessChainsNV"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, NoExtension) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_MISSING_EXTENSION, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("requires one of these extensions: SPV_NV_raw_access_chains"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, ReturnTypeNotPointer) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %int %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("must be OpTypePointer. Found OpTypeInt"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, Workgroup) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer Workgroup %intStruct
+ %ssbo = OpVariable %intStructPtr Workgroup
+ %intPtr = OpTypePointer Workgroup %int
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("must point to a storage class of"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, ReturnTypeArray) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %int_1 = OpConstant %int 1
+ %intArray = OpTypeArray %int %int_1
+ %intArrayPtr = OpTypePointer StorageBuffer %intArray
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intArrayPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("must not point to"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, VariableStride) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %stride = OpIAdd %int %int_0 %int_0
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %stride %int_0 %int_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("must be OpConstant"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, RobustnessPerElementZeroStride) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_0 %int_0 %int_0 RobustnessPerElementNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Stride must not be zero when per-element robustness is used"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, BothRobustness) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerElementNV|RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Per-component robustness and per-element robustness "
+ "are mutually exclusive"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, StrideFloat) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+ %float_16 = OpConstant %float 16
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %float_16 %int_0 %int_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("must be OpTypeInt"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, IndexType) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpCapability Int64
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %long = OpTypeInt 64 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+ %long_0 = OpConstant %long 0
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %long_0 %int_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("The integer width of Index"));
+}
+
+TEST_F(ValidateSpvNVRawAccessChains, OffsetType) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpCapability RawAccessChainsNV
+ OpCapability Int64
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_NV_raw_access_chains"
+ OpMemoryModel Logical GLSL450
+
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %intStruct Block
+ OpMemberDecorate %intStruct 0 Offset 0
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+
+ %int = OpTypeInt 32 1
+ %long = OpTypeInt 64 1
+ %void = OpTypeVoid
+ %mainFunctionType = OpTypeFunction %void
+ %intStruct = OpTypeStruct %int
+ %intStructPtr = OpTypePointer StorageBuffer %intStruct
+ %ssbo = OpVariable %intStructPtr StorageBuffer
+ %intPtr = OpTypePointer StorageBuffer %int
+
+ %int_0 = OpConstant %int 0
+ %int_16 = OpConstant %int 16
+ %long_0 = OpConstant %long 0
+
+ %main = OpFunction %void None %mainFunctionType
+ %label = OpLabel
+ %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %long_0 RobustnessPerComponentNV
+ %unused = OpLoad %int %rawChain
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("The integer width of Offset"));
+}
+
+} // namespace
+} // namespace val
+} // namespace spvtools