diff options
author | Spencer Fricke <115671160+spencer-lunarg@users.noreply.github.com> | 2022-12-06 23:00:10 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-06 09:00:10 -0500 |
commit | 7b8f00f00a5b18374a294f728ec87565c2fc4ca1 (patch) | |
tree | 233f6fa6752336d993cbef29228664c004708951 | |
parent | 40f5bf59c6acb4754a0bffd3c53a715732883a12 (diff) | |
download | spirv-tools-7b8f00f00a5b18374a294f728ec87565c2fc4ca1.tar.gz |
spirv-opt: Fix OpCompositeInsert with Null Constant (#5008)
* spirv-opt: Unify GetConstId function names
* spirv-opt: Fix OpCompositeInsert with Null Constant
* spirv-opt: Improve GetNullCompositeConstant description
-rw-r--r-- | source/opt/const_folding_rules.cpp | 51 | ||||
-rw-r--r-- | source/opt/constants.cpp | 46 | ||||
-rw-r--r-- | source/opt/constants.h | 15 | ||||
-rw-r--r-- | source/opt/debug_info_manager.cpp | 5 | ||||
-rw-r--r-- | source/opt/eliminate_dead_io_components_pass.cpp | 2 | ||||
-rw-r--r-- | source/opt/interface_var_sroa.cpp | 4 | ||||
-rw-r--r-- | source/opt/replace_desc_array_access_using_var_index.cpp | 2 | ||||
-rw-r--r-- | source/opt/scalar_replacement_pass.cpp | 4 | ||||
-rw-r--r-- | test/opt/fold_spec_const_op_composite_test.cpp | 287 |
9 files changed, 373 insertions, 43 deletions
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 19b39d63..14f22089 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -136,32 +136,38 @@ ConstantFoldingRule FoldInsertWithConstants() { std::vector<const analysis::Constant*> chain; std::vector<const analysis::Constant*> components; const analysis::Type* type = nullptr; + const uint32_t final_index = (inst->NumInOperands() - 1); - // Work down hierarchy and add all the indexes, not including the final - // index. + // Work down hierarchy of all indexes for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { - if (composite->AsNullConstant()) { - // Return Null for the return type. - analysis::TypeManager* type_mgr = context->get_type_mgr(); - return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); - } + type = composite->type(); - if (i != inst->NumInOperands() - 1) { - chain.push_back(composite); + if (composite->AsNullConstant()) { + // Make new composite so it can be inserted in the index with the + // non-null value + const auto new_composite = const_mgr->GetNullCompositeConstant(type); + // Keep track of any indexes along the way to last index + if (i != final_index) { + chain.push_back(new_composite); + } + components = new_composite->AsCompositeConstant()->GetComponents(); + } else { + // Keep track of any indexes along the way to last index + if (i != final_index) { + chain.push_back(composite); + } + components = composite->AsCompositeConstant()->GetComponents(); } const uint32_t index = inst->GetSingleWordInOperand(i); - components = composite->AsCompositeConstant()->GetComponents(); - type = composite->AsCompositeConstant()->type(); composite = components[index]; } // Final index in hierarchy is inserted with new object. - const uint32_t final_index = - inst->GetSingleWordInOperand(inst->NumInOperands() - 1); + const uint32_t final_operand = inst->GetSingleWordInOperand(final_index); std::vector<uint32_t> ids; for (size_t i = 0; i < components.size(); i++) { const analysis::Constant* constant = - (i == final_index) ? object : components[i]; + (i == final_operand) ? object : components[i]; Instruction* member_inst = const_mgr->GetDefiningInstruction(constant); ids.push_back(member_inst->result_id()); } @@ -171,19 +177,16 @@ ConstantFoldingRule FoldInsertWithConstants() { for (size_t i = chain.size(); i > 0; i--) { // Need to insert any previous instruction into the module first. // Can't just insert in types_values_begin() because it will move above - // where the types are declared - for (Module::inst_iterator inst_iter = context->types_values_begin(); - inst_iter != context->types_values_end(); ++inst_iter) { - Instruction* x = &*inst_iter; - if (inst->result_id() == x->result_id()) { - const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter); - break; - } - } + // where the types are declared. + // Can't compare with location of inst because not all new added + // instructions are added to types_values_ + auto iter = context->types_values_end(); + Module::inst_iterator* pos = &iter; + const_mgr->BuildInstructionAndAddToModule(new_constant, pos); composite = chain[i - 1]; components = composite->AsCompositeConstant()->GetComponents(); - type = composite->AsCompositeConstant()->type(); + type = composite->type(); ids.clear(); for (size_t k = 0; k < components.size(); k++) { const uint32_t index = diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp index 9930b44b..d70e27bb 100644 --- a/source/opt/constants.cpp +++ b/source/opt/constants.cpp @@ -391,6 +391,43 @@ const Constant* ConstantManager::GetConstant( return cst ? RegisterConstant(std::move(cst)) : nullptr; } +const Constant* ConstantManager::GetNullCompositeConstant(const Type* type) { + std::vector<uint32_t> literal_words_or_id; + + if (type->AsVector()) { + const Type* element_type = type->AsVector()->element_type(); + const uint32_t null_id = GetNullConstId(element_type); + const uint32_t element_count = type->AsVector()->element_count(); + for (uint32_t i = 0; i < element_count; i++) { + literal_words_or_id.push_back(null_id); + } + } else if (type->AsMatrix()) { + const Type* element_type = type->AsMatrix()->element_type(); + const uint32_t null_id = GetNullConstId(element_type); + const uint32_t element_count = type->AsMatrix()->element_count(); + for (uint32_t i = 0; i < element_count; i++) { + literal_words_or_id.push_back(null_id); + } + } else if (type->AsStruct()) { + // TODO (sfricke-lunarg) add proper struct support + return nullptr; + } else if (type->AsArray()) { + const Type* element_type = type->AsArray()->element_type(); + const uint32_t null_id = GetNullConstId(element_type); + assert(type->AsArray()->length_info().words[0] == + analysis::Array::LengthInfo::kConstant && + "unexpected array length"); + const uint32_t element_count = type->AsArray()->length_info().words[0]; + for (uint32_t i = 0; i < element_count; i++) { + literal_words_or_id.push_back(null_id); + } + } else { + return nullptr; + } + + return GetConstant(type, literal_words_or_id); +} + const Constant* ConstantManager::GetNumericVectorConstantWithWords( const Vector* type, const std::vector<uint32_t>& literal_words) { const auto* element_type = type->element_type(); @@ -445,18 +482,23 @@ const Constant* ConstantManager::GetDoubleConst(double val) { return c; } -uint32_t ConstantManager::GetSIntConst(int32_t val) { +uint32_t ConstantManager::GetSIntConstId(int32_t val) { Type* sint_type = context()->get_type_mgr()->GetSIntType(); const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)}); return GetDefiningInstruction(c)->result_id(); } -uint32_t ConstantManager::GetUIntConst(uint32_t val) { +uint32_t ConstantManager::GetUIntConstId(uint32_t val) { Type* uint_type = context()->get_type_mgr()->GetUIntType(); const Constant* c = GetConstant(uint_type, {val}); return GetDefiningInstruction(c)->result_id(); } +uint32_t ConstantManager::GetNullConstId(const Type* type) { + const Constant* c = GetConstant(type, {}); + return GetDefiningInstruction(c)->result_id(); +} + std::vector<const analysis::Constant*> Constant::GetVectorComponents( analysis::ConstantManager* const_mgr) const { std::vector<const analysis::Constant*> components; diff --git a/source/opt/constants.h b/source/opt/constants.h index 588ca3e7..410304ea 100644 --- a/source/opt/constants.h +++ b/source/opt/constants.h @@ -520,6 +520,14 @@ class ConstantManager { literal_words_or_ids.end())); } + // Takes a type and creates a OpConstantComposite + // This allows a + // OpConstantNull %composite_type + // to become a + // OpConstantComposite %composite_type %null %null ... etc + // Assumes type is a Composite already, otherwise returns null + const Constant* GetNullCompositeConstant(const Type* type); + // Gets or creates a unique Constant instance of Vector type |type| with // numeric elements and a vector of constant defining words |literal_words|. // If a Constant instance existed already in the constant pool, it returns a @@ -649,10 +657,13 @@ class ConstantManager { const Constant* GetDoubleConst(double val); // Returns the id of a 32-bit signed integer constant with value |val|. - uint32_t GetSIntConst(int32_t val); + uint32_t GetSIntConstId(int32_t val); // Returns the id of a 32-bit unsigned integer constant with value |val|. - uint32_t GetUIntConst(uint32_t val); + uint32_t GetUIntConstId(uint32_t val); + + // Returns the id of a OpConstantNull with type of |type|. + uint32_t GetNullConstId(const Type* type); private: // Creates a Constant instance with the given type and a vector of constant diff --git a/source/opt/debug_info_manager.cpp b/source/opt/debug_info_manager.cpp index 0ec392f5..1e614c6f 100644 --- a/source/opt/debug_info_manager.cpp +++ b/source/opt/debug_info_manager.cpp @@ -235,7 +235,8 @@ uint32_t DebugInfoManager::CreateDebugInlinedAt(const Instruction* line, !context()->AreAnalysesValid(IRContext::Analysis::kAnalysisConstants)) line_number = AddNewConstInGlobals(context(), line_number); else - line_number = context()->get_constant_mgr()->GetUIntConst(line_number); + line_number = + context()->get_constant_mgr()->GetUIntConstId(line_number); } } @@ -344,7 +345,7 @@ Instruction* DebugInfoManager::GetDebugOperationWithDeref() { {static_cast<uint32_t>(OpenCLDebugInfo100Deref)}}, })); } else { - uint32_t deref_id = context()->get_constant_mgr()->GetUIntConst( + uint32_t deref_id = context()->get_constant_mgr()->GetUIntConstId( NonSemanticShaderDebugInfo100Deref); deref_operation = std::unique_ptr<Instruction>( diff --git a/source/opt/eliminate_dead_io_components_pass.cpp b/source/opt/eliminate_dead_io_components_pass.cpp index df596454..e430c6d5 100644 --- a/source/opt/eliminate_dead_io_components_pass.cpp +++ b/source/opt/eliminate_dead_io_components_pass.cpp @@ -197,7 +197,7 @@ void EliminateDeadIOComponentsPass::ChangeArrayLength(Instruction& arr_var, type_mgr->GetType(arr_var.type_id())->AsPointer(); const analysis::Array* arr_ty = ptr_type->pointee_type()->AsArray(); assert(arr_ty && "expecting array type"); - uint32_t length_id = const_mgr->GetUIntConst(length); + uint32_t length_id = const_mgr->GetUIntConstId(length); analysis::Array new_arr_ty(arr_ty->element_type(), arr_ty->GetConstantLengthInfo(length_id, length)); analysis::Type* reg_new_arr_ty = type_mgr->GetRegisteredType(&new_arr_ty); diff --git a/source/opt/interface_var_sroa.cpp b/source/opt/interface_var_sroa.cpp index 8205c75f..08477cbd 100644 --- a/source/opt/interface_var_sroa.cpp +++ b/source/opt/interface_var_sroa.cpp @@ -489,7 +489,7 @@ Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex( Instruction* insert_before) { uint32_t ptr_type_id = GetPointerType(component_type_id, GetStorageClass(var)); - uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index); + uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(index); std::unique_ptr<Instruction> new_access_chain(new Instruction( context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(), std::initializer_list<Operand>{ @@ -781,7 +781,7 @@ uint32_t InterfaceVariableScalarReplacement::GetArrayType( uint32_t elem_type_id, uint32_t array_length) { analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id); uint32_t array_length_id = - context()->get_constant_mgr()->GetUIntConst(array_length); + context()->get_constant_mgr()->GetUIntConstId(array_length); analysis::Array array_type( elem_type, analysis::Array::LengthInfo{array_length_id, {0, array_length}}); diff --git a/source/opt/replace_desc_array_access_using_var_index.cpp b/source/opt/replace_desc_array_access_using_var_index.cpp index 93c77d34..59745e12 100644 --- a/source/opt/replace_desc_array_access_using_var_index.cpp +++ b/source/opt/replace_desc_array_access_using_var_index.cpp @@ -331,7 +331,7 @@ BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateNewBlock() const { void ReplaceDescArrayAccessUsingVarIndex::UseConstIndexForAccessChain( Instruction* access_chain, uint32_t const_element_idx) const { uint32_t const_element_idx_id = - context()->get_constant_mgr()->GetUIntConst(const_element_idx); + context()->get_constant_mgr()->GetUIntConstId(const_element_idx); access_chain->SetInOperand(kOpAccessChainInOperandIndexes, {const_element_idx_id}); } diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp index 6045158a..bfebb01c 100644 --- a/source/opt/scalar_replacement_pass.cpp +++ b/source/opt/scalar_replacement_pass.cpp @@ -191,7 +191,7 @@ bool ScalarReplacementPass::ReplaceWholeDebugDeclare( if (added_dbg_value == nullptr) return false; added_dbg_value->AddOperand( {SPV_OPERAND_TYPE_ID, - {context()->get_constant_mgr()->GetSIntConst(idx)}}); + {context()->get_constant_mgr()->GetSIntConstId(idx)}}); added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex, {deref_expr->result_id()}); if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) { @@ -217,7 +217,7 @@ bool ScalarReplacementPass::ReplaceWholeDebugValue( // Append 'Indexes' operand. new_dbg_value->AddOperand( {SPV_OPERAND_TYPE_ID, - {context()->get_constant_mgr()->GetSIntConst(idx)}}); + {context()->get_constant_mgr()->GetSIntConstId(idx)}}); // Insert the new DebugValue to the basic block. auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value)); get_def_use_mgr()->AnalyzeInstDefUse(added_instr); diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp index aae9eb24..f83e86e9 100644 --- a/test/opt/fold_spec_const_op_composite_test.cpp +++ b/test/opt/fold_spec_const_op_composite_test.cpp @@ -340,6 +340,41 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVector) { SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); } +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, + CompositeInsertVectorIntoMatrix) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %mat2v2float = OpTypeMatrix %v2float 2 + %float_0 = OpConstant %float 0 + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %v2float_01 = OpConstantComposite %v2float %float_0 %float_1 + %v2float_12 = OpConstantComposite %v2float %float_1 %float_2 + +; CHECK: %10 = OpConstantComposite %v2float %float_0 %float_1 +; CHECK: %11 = OpConstantComposite %v2float %float_1 %float_2 +; CHECK: %12 = OpConstantComposite %mat2v2float %11 %11 +%mat2v2float_1212 = OpConstantComposite %mat2v2float %v2float_12 %v2float_12 + +; CHECK: %15 = OpConstantComposite %mat2v2float %10 %11 + %spec_0 = OpSpecConstantOp %mat2v2float CompositeInsert %v2float_01 %mat2v2float_1212 0 + %1 = OpFunction %void None %3 + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) { const std::string test = R"( @@ -374,7 +409,39 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) { SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); } -TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertNull) { +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertFloatNull) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %float_1 = OpConstant %float 1 + +; CHECK: %7 = OpConstantNull %float +; CHECK: %8 = OpConstantComposite %v3float %7 %7 %7 +; CHECK: %12 = OpConstantComposite %v3float %7 %7 %float_1 + %null = OpConstantNull %float + %spec_0 = OpConstantComposite %v3float %null %null %null + %spec_1 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_0 2 + +; CHECK: %float_1_0 = OpConstant %float 1 + %spec_2 = OpSpecConstantOp %float CompositeExtract %spec_1 2 + %1 = OpFunction %void None %3 + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, + CompositeInsertFloatSetNull) { const std::string test = R"( OpCapability Shader @@ -384,16 +451,222 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertNull) { %void = OpTypeVoid %3 = OpTypeFunction %void %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %float_1 = OpConstant %float 1 + +; CHECK: %7 = OpConstantNull %float +; CHECK: %8 = OpConstantComposite %v3float %7 %7 %float_1 +; CHECK: %12 = OpConstantComposite %v3float %7 %7 %7 + %null = OpConstantNull %float + %spec_0 = OpConstantComposite %v3float %null %null %float_1 + %spec_1 = OpSpecConstantOp %v3float CompositeInsert %null %spec_0 2 + +; CHECK: %13 = OpConstantNull %float + %spec_2 = OpSpecConstantOp %float CompositeExtract %spec_1 2 + %1 = OpFunction %void None %3 + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVectorNull) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %float_1 = OpConstant %float 1 + %null = OpConstantNull %v3float + +; CHECK: %11 = OpConstantNull %float +; CHECK: %12 = OpConstantComposite %v3float %11 %11 %float_1 + %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_1 %null 2 + + +; CHECK: %float_1_0 = OpConstant %float 1 + %spec_1 = OpSpecConstantOp %float CompositeExtract %spec_0 2 + %1 = OpFunction %void None %3 + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, + CompositeInsertNullVectorIntoMatrix) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %mat2v2float = OpTypeMatrix %v2float 2 + %null = OpConstantNull %mat2v2float + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %v2float_12 = OpConstantComposite %v2float %float_1 %float_2 + +; CHECK: %13 = OpConstantNull %v2float +; CHECK: %14 = OpConstantComposite %mat2v2float %10 %13 + %spec_0 = OpSpecConstantOp %mat2v2float CompositeInsert %v2float_12 %null 0 + %1 = OpFunction %void None %3 + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, + CompositeInsertVectorKeepNull) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %float_0 = OpConstant %float 0 + %null_float = OpConstantNull %float + %null_vec = OpConstantNull %v3float + +; CHECK: %15 = OpConstantComposite %v3float %7 %7 %float_0 + %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_0 %null_vec 2 + +; CHECK: %float_0_0 = OpConstant %float 0 + %spec_1 = OpSpecConstantOp %float CompositeExtract %spec_0 2 + +; CHECK: %17 = OpConstantComposite %v3float %7 %7 %7 + %spec_2 = OpSpecConstantOp %v3float CompositeInsert %null_float %null_vec 2 + +; CHECK: %18 = OpConstantNull %float + %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 2 + %1 = OpFunction %void None %3 + %label = OpLabel + %add = OpFAdd %float %spec_3 %spec_3 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, + CompositeInsertVectorChainNull) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %float_1 = OpConstant %float 1 + %null = OpConstantNull %v3float + +; CHECK: %15 = OpConstantNull %float +; CHECK: %16 = OpConstantComposite %v3float %15 %15 %float_1 +; CHECK: %17 = OpConstantComposite %v3float %15 %float_1 %float_1 +; CHECK: %18 = OpConstantComposite %v3float %float_1 %float_1 %float_1 + %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_1 %null 2 + %spec_1 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_0 1 + %spec_2 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_1 0 + +; CHECK: %float_1_0 = OpConstant %float 1 +; CHECK: %float_1_1 = OpConstant %float 1 +; CHECK: %float_1_2 = OpConstant %float 1 + %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 0 + %spec_4 = OpSpecConstantOp %float CompositeExtract %spec_2 1 + %spec_5 = OpSpecConstantOp %float CompositeExtract %spec_2 2 + %1 = OpFunction %void None %3 + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, + CompositeInsertVectorChainReset) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %float_1 = OpConstant %float 1 + %null = OpConstantNull %float +; CHECK: %8 = OpConstantComposite %v3float %7 %7 %float_1 + %spec_0 = OpConstantComposite %v3float %null %null %float_1 + + ; set to null +; CHECK: %13 = OpConstantComposite %v3float %7 %7 %7 + %spec_1 = OpSpecConstantOp %v3float CompositeInsert %null %spec_0 2 + + ; set to back to original value +; CHECK: %14 = OpConstantComposite %v3float %7 %7 %float_1 + %spec_2 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_1 2 + +; CHECK: %float_1_0 = OpConstant %float 1 + %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 2 + %1 = OpFunction %void None %3 + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrixNull) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + %void = OpTypeVoid + %func = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 0 %v2float = OpTypeVector %float 2 %mat2v2float = OpTypeMatrix %v2float 2 %null = OpConstantNull %mat2v2float %float_1 = OpConstant %float 1 - %v2float_1 = OpConstantComposite %v2float %float_1 %float_1 - %mat2v2_1 = OpConstantComposite %mat2v2float %v2float_1 %v2float_1 - ; CHECK: %13 = OpConstantNull %mat2v2float - %14 = OpSpecConstantOp %mat2v2float CompositeInsert %mat2v2_1 %null 0 0 - %1 = OpFunction %void None %3 - %16 = OpLabel + ; CHECK: %13 = OpConstantNull %v2float + ; CHECK: %14 = OpConstantNull %float + ; CHECK: %15 = OpConstantComposite %v2float %float_1 %14 + ; CHECK: %16 = OpConstantComposite %mat2v2float %13 %15 + %spec = OpSpecConstantOp %mat2v2float CompositeInsert %float_1 %null 1 0 +; extra type def to make sure new type def are not just thrown at end + %v2int = OpTypeVector %int 2 + %main = OpFunction %void None %func + %label = OpLabel OpReturn OpFunctionEnd )"; |