aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSpencer Fricke <115671160+spencer-lunarg@users.noreply.github.com>2022-12-06 23:00:10 +0900
committerGitHub <noreply@github.com>2022-12-06 09:00:10 -0500
commit7b8f00f00a5b18374a294f728ec87565c2fc4ca1 (patch)
tree233f6fa6752336d993cbef29228664c004708951
parent40f5bf59c6acb4754a0bffd3c53a715732883a12 (diff)
downloadspirv-tools-7b8f00f00a5b18374a294f728ec87565c2fc4ca1.tar.gz
spirv-opt: Fix OpCompositeInsert with Null Constant (#5008)
* spirv-opt: Unify GetConstId function names * spirv-opt: Fix OpCompositeInsert with Null Constant * spirv-opt: Improve GetNullCompositeConstant description
-rw-r--r--source/opt/const_folding_rules.cpp51
-rw-r--r--source/opt/constants.cpp46
-rw-r--r--source/opt/constants.h15
-rw-r--r--source/opt/debug_info_manager.cpp5
-rw-r--r--source/opt/eliminate_dead_io_components_pass.cpp2
-rw-r--r--source/opt/interface_var_sroa.cpp4
-rw-r--r--source/opt/replace_desc_array_access_using_var_index.cpp2
-rw-r--r--source/opt/scalar_replacement_pass.cpp4
-rw-r--r--test/opt/fold_spec_const_op_composite_test.cpp287
9 files changed, 373 insertions, 43 deletions
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 19b39d63..14f22089 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -136,32 +136,38 @@ ConstantFoldingRule FoldInsertWithConstants() {
std::vector<const analysis::Constant*> chain;
std::vector<const analysis::Constant*> components;
const analysis::Type* type = nullptr;
+ const uint32_t final_index = (inst->NumInOperands() - 1);
- // Work down hierarchy and add all the indexes, not including the final
- // index.
+ // Work down hierarchy of all indexes
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
- if (composite->AsNullConstant()) {
- // Return Null for the return type.
- analysis::TypeManager* type_mgr = context->get_type_mgr();
- return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
- }
+ type = composite->type();
- if (i != inst->NumInOperands() - 1) {
- chain.push_back(composite);
+ if (composite->AsNullConstant()) {
+ // Make new composite so it can be inserted in the index with the
+ // non-null value
+ const auto new_composite = const_mgr->GetNullCompositeConstant(type);
+ // Keep track of any indexes along the way to last index
+ if (i != final_index) {
+ chain.push_back(new_composite);
+ }
+ components = new_composite->AsCompositeConstant()->GetComponents();
+ } else {
+ // Keep track of any indexes along the way to last index
+ if (i != final_index) {
+ chain.push_back(composite);
+ }
+ components = composite->AsCompositeConstant()->GetComponents();
}
const uint32_t index = inst->GetSingleWordInOperand(i);
- components = composite->AsCompositeConstant()->GetComponents();
- type = composite->AsCompositeConstant()->type();
composite = components[index];
}
// Final index in hierarchy is inserted with new object.
- const uint32_t final_index =
- inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
+ const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
std::vector<uint32_t> ids;
for (size_t i = 0; i < components.size(); i++) {
const analysis::Constant* constant =
- (i == final_index) ? object : components[i];
+ (i == final_operand) ? object : components[i];
Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
ids.push_back(member_inst->result_id());
}
@@ -171,19 +177,16 @@ ConstantFoldingRule FoldInsertWithConstants() {
for (size_t i = chain.size(); i > 0; i--) {
// Need to insert any previous instruction into the module first.
// Can't just insert in types_values_begin() because it will move above
- // where the types are declared
- for (Module::inst_iterator inst_iter = context->types_values_begin();
- inst_iter != context->types_values_end(); ++inst_iter) {
- Instruction* x = &*inst_iter;
- if (inst->result_id() == x->result_id()) {
- const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter);
- break;
- }
- }
+ // where the types are declared.
+ // Can't compare with location of inst because not all new added
+ // instructions are added to types_values_
+ auto iter = context->types_values_end();
+ Module::inst_iterator* pos = &iter;
+ const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
composite = chain[i - 1];
components = composite->AsCompositeConstant()->GetComponents();
- type = composite->AsCompositeConstant()->type();
+ type = composite->type();
ids.clear();
for (size_t k = 0; k < components.size(); k++) {
const uint32_t index =
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index 9930b44b..d70e27bb 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -391,6 +391,43 @@ const Constant* ConstantManager::GetConstant(
return cst ? RegisterConstant(std::move(cst)) : nullptr;
}
+const Constant* ConstantManager::GetNullCompositeConstant(const Type* type) {
+ std::vector<uint32_t> literal_words_or_id;
+
+ if (type->AsVector()) {
+ const Type* element_type = type->AsVector()->element_type();
+ const uint32_t null_id = GetNullConstId(element_type);
+ const uint32_t element_count = type->AsVector()->element_count();
+ for (uint32_t i = 0; i < element_count; i++) {
+ literal_words_or_id.push_back(null_id);
+ }
+ } else if (type->AsMatrix()) {
+ const Type* element_type = type->AsMatrix()->element_type();
+ const uint32_t null_id = GetNullConstId(element_type);
+ const uint32_t element_count = type->AsMatrix()->element_count();
+ for (uint32_t i = 0; i < element_count; i++) {
+ literal_words_or_id.push_back(null_id);
+ }
+ } else if (type->AsStruct()) {
+ // TODO (sfricke-lunarg) add proper struct support
+ return nullptr;
+ } else if (type->AsArray()) {
+ const Type* element_type = type->AsArray()->element_type();
+ const uint32_t null_id = GetNullConstId(element_type);
+ assert(type->AsArray()->length_info().words[0] ==
+ analysis::Array::LengthInfo::kConstant &&
+ "unexpected array length");
+ const uint32_t element_count = type->AsArray()->length_info().words[0];
+ for (uint32_t i = 0; i < element_count; i++) {
+ literal_words_or_id.push_back(null_id);
+ }
+ } else {
+ return nullptr;
+ }
+
+ return GetConstant(type, literal_words_or_id);
+}
+
const Constant* ConstantManager::GetNumericVectorConstantWithWords(
const Vector* type, const std::vector<uint32_t>& literal_words) {
const auto* element_type = type->element_type();
@@ -445,18 +482,23 @@ const Constant* ConstantManager::GetDoubleConst(double val) {
return c;
}
-uint32_t ConstantManager::GetSIntConst(int32_t val) {
+uint32_t ConstantManager::GetSIntConstId(int32_t val) {
Type* sint_type = context()->get_type_mgr()->GetSIntType();
const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
return GetDefiningInstruction(c)->result_id();
}
-uint32_t ConstantManager::GetUIntConst(uint32_t val) {
+uint32_t ConstantManager::GetUIntConstId(uint32_t val) {
Type* uint_type = context()->get_type_mgr()->GetUIntType();
const Constant* c = GetConstant(uint_type, {val});
return GetDefiningInstruction(c)->result_id();
}
+uint32_t ConstantManager::GetNullConstId(const Type* type) {
+ const Constant* c = GetConstant(type, {});
+ return GetDefiningInstruction(c)->result_id();
+}
+
std::vector<const analysis::Constant*> Constant::GetVectorComponents(
analysis::ConstantManager* const_mgr) const {
std::vector<const analysis::Constant*> components;
diff --git a/source/opt/constants.h b/source/opt/constants.h
index 588ca3e7..410304ea 100644
--- a/source/opt/constants.h
+++ b/source/opt/constants.h
@@ -520,6 +520,14 @@ class ConstantManager {
literal_words_or_ids.end()));
}
+ // Takes a type and creates a OpConstantComposite
+ // This allows a
+ // OpConstantNull %composite_type
+ // to become a
+ // OpConstantComposite %composite_type %null %null ... etc
+ // Assumes type is a Composite already, otherwise returns null
+ const Constant* GetNullCompositeConstant(const Type* type);
+
// Gets or creates a unique Constant instance of Vector type |type| with
// numeric elements and a vector of constant defining words |literal_words|.
// If a Constant instance existed already in the constant pool, it returns a
@@ -649,10 +657,13 @@ class ConstantManager {
const Constant* GetDoubleConst(double val);
// Returns the id of a 32-bit signed integer constant with value |val|.
- uint32_t GetSIntConst(int32_t val);
+ uint32_t GetSIntConstId(int32_t val);
// Returns the id of a 32-bit unsigned integer constant with value |val|.
- uint32_t GetUIntConst(uint32_t val);
+ uint32_t GetUIntConstId(uint32_t val);
+
+ // Returns the id of a OpConstantNull with type of |type|.
+ uint32_t GetNullConstId(const Type* type);
private:
// Creates a Constant instance with the given type and a vector of constant
diff --git a/source/opt/debug_info_manager.cpp b/source/opt/debug_info_manager.cpp
index 0ec392f5..1e614c6f 100644
--- a/source/opt/debug_info_manager.cpp
+++ b/source/opt/debug_info_manager.cpp
@@ -235,7 +235,8 @@ uint32_t DebugInfoManager::CreateDebugInlinedAt(const Instruction* line,
!context()->AreAnalysesValid(IRContext::Analysis::kAnalysisConstants))
line_number = AddNewConstInGlobals(context(), line_number);
else
- line_number = context()->get_constant_mgr()->GetUIntConst(line_number);
+ line_number =
+ context()->get_constant_mgr()->GetUIntConstId(line_number);
}
}
@@ -344,7 +345,7 @@ Instruction* DebugInfoManager::GetDebugOperationWithDeref() {
{static_cast<uint32_t>(OpenCLDebugInfo100Deref)}},
}));
} else {
- uint32_t deref_id = context()->get_constant_mgr()->GetUIntConst(
+ uint32_t deref_id = context()->get_constant_mgr()->GetUIntConstId(
NonSemanticShaderDebugInfo100Deref);
deref_operation = std::unique_ptr<Instruction>(
diff --git a/source/opt/eliminate_dead_io_components_pass.cpp b/source/opt/eliminate_dead_io_components_pass.cpp
index df596454..e430c6d5 100644
--- a/source/opt/eliminate_dead_io_components_pass.cpp
+++ b/source/opt/eliminate_dead_io_components_pass.cpp
@@ -197,7 +197,7 @@ void EliminateDeadIOComponentsPass::ChangeArrayLength(Instruction& arr_var,
type_mgr->GetType(arr_var.type_id())->AsPointer();
const analysis::Array* arr_ty = ptr_type->pointee_type()->AsArray();
assert(arr_ty && "expecting array type");
- uint32_t length_id = const_mgr->GetUIntConst(length);
+ uint32_t length_id = const_mgr->GetUIntConstId(length);
analysis::Array new_arr_ty(arr_ty->element_type(),
arr_ty->GetConstantLengthInfo(length_id, length));
analysis::Type* reg_new_arr_ty = type_mgr->GetRegisteredType(&new_arr_ty);
diff --git a/source/opt/interface_var_sroa.cpp b/source/opt/interface_var_sroa.cpp
index 8205c75f..08477cbd 100644
--- a/source/opt/interface_var_sroa.cpp
+++ b/source/opt/interface_var_sroa.cpp
@@ -489,7 +489,7 @@ Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
Instruction* insert_before) {
uint32_t ptr_type_id =
GetPointerType(component_type_id, GetStorageClass(var));
- uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index);
+ uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(index);
std::unique_ptr<Instruction> new_access_chain(new Instruction(
context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
std::initializer_list<Operand>{
@@ -781,7 +781,7 @@ uint32_t InterfaceVariableScalarReplacement::GetArrayType(
uint32_t elem_type_id, uint32_t array_length) {
analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
uint32_t array_length_id =
- context()->get_constant_mgr()->GetUIntConst(array_length);
+ context()->get_constant_mgr()->GetUIntConstId(array_length);
analysis::Array array_type(
elem_type,
analysis::Array::LengthInfo{array_length_id, {0, array_length}});
diff --git a/source/opt/replace_desc_array_access_using_var_index.cpp b/source/opt/replace_desc_array_access_using_var_index.cpp
index 93c77d34..59745e12 100644
--- a/source/opt/replace_desc_array_access_using_var_index.cpp
+++ b/source/opt/replace_desc_array_access_using_var_index.cpp
@@ -331,7 +331,7 @@ BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateNewBlock() const {
void ReplaceDescArrayAccessUsingVarIndex::UseConstIndexForAccessChain(
Instruction* access_chain, uint32_t const_element_idx) const {
uint32_t const_element_idx_id =
- context()->get_constant_mgr()->GetUIntConst(const_element_idx);
+ context()->get_constant_mgr()->GetUIntConstId(const_element_idx);
access_chain->SetInOperand(kOpAccessChainInOperandIndexes,
{const_element_idx_id});
}
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index 6045158a..bfebb01c 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -191,7 +191,7 @@ bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
if (added_dbg_value == nullptr) return false;
added_dbg_value->AddOperand(
{SPV_OPERAND_TYPE_ID,
- {context()->get_constant_mgr()->GetSIntConst(idx)}});
+ {context()->get_constant_mgr()->GetSIntConstId(idx)}});
added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
{deref_expr->result_id()});
if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
@@ -217,7 +217,7 @@ bool ScalarReplacementPass::ReplaceWholeDebugValue(
// Append 'Indexes' operand.
new_dbg_value->AddOperand(
{SPV_OPERAND_TYPE_ID,
- {context()->get_constant_mgr()->GetSIntConst(idx)}});
+ {context()->get_constant_mgr()->GetSIntConstId(idx)}});
// Insert the new DebugValue to the basic block.
auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp
index aae9eb24..f83e86e9 100644
--- a/test/opt/fold_spec_const_op_composite_test.cpp
+++ b/test/opt/fold_spec_const_op_composite_test.cpp
@@ -340,6 +340,41 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVector) {
SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
}
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+ CompositeInsertVectorIntoMatrix) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+ %mat2v2float = OpTypeMatrix %v2float 2
+ %float_0 = OpConstant %float 0
+ %float_1 = OpConstant %float 1
+ %float_2 = OpConstant %float 2
+ %v2float_01 = OpConstantComposite %v2float %float_0 %float_1
+ %v2float_12 = OpConstantComposite %v2float %float_1 %float_2
+
+; CHECK: %10 = OpConstantComposite %v2float %float_0 %float_1
+; CHECK: %11 = OpConstantComposite %v2float %float_1 %float_2
+; CHECK: %12 = OpConstantComposite %mat2v2float %11 %11
+%mat2v2float_1212 = OpConstantComposite %mat2v2float %v2float_12 %v2float_12
+
+; CHECK: %15 = OpConstantComposite %mat2v2float %10 %11
+ %spec_0 = OpSpecConstantOp %mat2v2float CompositeInsert %v2float_01 %mat2v2float_1212 0
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) {
const std::string test =
R"(
@@ -374,7 +409,39 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) {
SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
}
-TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertNull) {
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertFloatNull) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v3float = OpTypeVector %float 3
+ %float_1 = OpConstant %float 1
+
+; CHECK: %7 = OpConstantNull %float
+; CHECK: %8 = OpConstantComposite %v3float %7 %7 %7
+; CHECK: %12 = OpConstantComposite %v3float %7 %7 %float_1
+ %null = OpConstantNull %float
+ %spec_0 = OpConstantComposite %v3float %null %null %null
+ %spec_1 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_0 2
+
+; CHECK: %float_1_0 = OpConstant %float 1
+ %spec_2 = OpSpecConstantOp %float CompositeExtract %spec_1 2
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+ CompositeInsertFloatSetNull) {
const std::string test =
R"(
OpCapability Shader
@@ -384,16 +451,222 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertNull) {
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
+ %v3float = OpTypeVector %float 3
+ %float_1 = OpConstant %float 1
+
+; CHECK: %7 = OpConstantNull %float
+; CHECK: %8 = OpConstantComposite %v3float %7 %7 %float_1
+; CHECK: %12 = OpConstantComposite %v3float %7 %7 %7
+ %null = OpConstantNull %float
+ %spec_0 = OpConstantComposite %v3float %null %null %float_1
+ %spec_1 = OpSpecConstantOp %v3float CompositeInsert %null %spec_0 2
+
+; CHECK: %13 = OpConstantNull %float
+ %spec_2 = OpSpecConstantOp %float CompositeExtract %spec_1 2
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVectorNull) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v3float = OpTypeVector %float 3
+ %float_1 = OpConstant %float 1
+ %null = OpConstantNull %v3float
+
+; CHECK: %11 = OpConstantNull %float
+; CHECK: %12 = OpConstantComposite %v3float %11 %11 %float_1
+ %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_1 %null 2
+
+
+; CHECK: %float_1_0 = OpConstant %float 1
+ %spec_1 = OpSpecConstantOp %float CompositeExtract %spec_0 2
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+ CompositeInsertNullVectorIntoMatrix) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+ %mat2v2float = OpTypeMatrix %v2float 2
+ %null = OpConstantNull %mat2v2float
+ %float_1 = OpConstant %float 1
+ %float_2 = OpConstant %float 2
+ %v2float_12 = OpConstantComposite %v2float %float_1 %float_2
+
+; CHECK: %13 = OpConstantNull %v2float
+; CHECK: %14 = OpConstantComposite %mat2v2float %10 %13
+ %spec_0 = OpSpecConstantOp %mat2v2float CompositeInsert %v2float_12 %null 0
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+ CompositeInsertVectorKeepNull) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v3float = OpTypeVector %float 3
+ %float_0 = OpConstant %float 0
+ %null_float = OpConstantNull %float
+ %null_vec = OpConstantNull %v3float
+
+; CHECK: %15 = OpConstantComposite %v3float %7 %7 %float_0
+ %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_0 %null_vec 2
+
+; CHECK: %float_0_0 = OpConstant %float 0
+ %spec_1 = OpSpecConstantOp %float CompositeExtract %spec_0 2
+
+; CHECK: %17 = OpConstantComposite %v3float %7 %7 %7
+ %spec_2 = OpSpecConstantOp %v3float CompositeInsert %null_float %null_vec 2
+
+; CHECK: %18 = OpConstantNull %float
+ %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 2
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ %add = OpFAdd %float %spec_3 %spec_3
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+ CompositeInsertVectorChainNull) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v3float = OpTypeVector %float 3
+ %float_1 = OpConstant %float 1
+ %null = OpConstantNull %v3float
+
+; CHECK: %15 = OpConstantNull %float
+; CHECK: %16 = OpConstantComposite %v3float %15 %15 %float_1
+; CHECK: %17 = OpConstantComposite %v3float %15 %float_1 %float_1
+; CHECK: %18 = OpConstantComposite %v3float %float_1 %float_1 %float_1
+ %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_1 %null 2
+ %spec_1 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_0 1
+ %spec_2 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_1 0
+
+; CHECK: %float_1_0 = OpConstant %float 1
+; CHECK: %float_1_1 = OpConstant %float 1
+; CHECK: %float_1_2 = OpConstant %float 1
+ %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 0
+ %spec_4 = OpSpecConstantOp %float CompositeExtract %spec_2 1
+ %spec_5 = OpSpecConstantOp %float CompositeExtract %spec_2 2
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+ CompositeInsertVectorChainReset) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v3float = OpTypeVector %float 3
+ %float_1 = OpConstant %float 1
+ %null = OpConstantNull %float
+; CHECK: %8 = OpConstantComposite %v3float %7 %7 %float_1
+ %spec_0 = OpConstantComposite %v3float %null %null %float_1
+
+ ; set to null
+; CHECK: %13 = OpConstantComposite %v3float %7 %7 %7
+ %spec_1 = OpSpecConstantOp %v3float CompositeInsert %null %spec_0 2
+
+ ; set to back to original value
+; CHECK: %14 = OpConstantComposite %v3float %7 %7 %float_1
+ %spec_2 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_1 2
+
+; CHECK: %float_1_0 = OpConstant %float 1
+ %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 2
+ %1 = OpFunction %void None %3
+ %label = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrixNull) {
+ const std::string test =
+ R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %int = OpTypeInt 32 0
%v2float = OpTypeVector %float 2
%mat2v2float = OpTypeMatrix %v2float 2
%null = OpConstantNull %mat2v2float
%float_1 = OpConstant %float 1
- %v2float_1 = OpConstantComposite %v2float %float_1 %float_1
- %mat2v2_1 = OpConstantComposite %mat2v2float %v2float_1 %v2float_1
- ; CHECK: %13 = OpConstantNull %mat2v2float
- %14 = OpSpecConstantOp %mat2v2float CompositeInsert %mat2v2_1 %null 0 0
- %1 = OpFunction %void None %3
- %16 = OpLabel
+ ; CHECK: %13 = OpConstantNull %v2float
+ ; CHECK: %14 = OpConstantNull %float
+ ; CHECK: %15 = OpConstantComposite %v2float %float_1 %14
+ ; CHECK: %16 = OpConstantComposite %mat2v2float %13 %15
+ %spec = OpSpecConstantOp %mat2v2float CompositeInsert %float_1 %null 1 0
+; extra type def to make sure new type def are not just thrown at end
+ %v2int = OpTypeVector %int 2
+ %main = OpFunction %void None %func
+ %label = OpLabel
OpReturn
OpFunctionEnd
)";