diff options
author | Steven Perron <stevenperron@google.com> | 2022-08-31 11:06:15 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-31 11:06:15 -0400 |
commit | d51dc53d2caf25024c7721647ed2a23819bd509c (patch) | |
tree | fbb897dfed3d0b83ec6b7e0432fb09ecf783ba1a | |
parent | fca39d5cb4420f391aacaa0d506c65544663754b (diff) | |
download | SPIRV-Tools-d51dc53d2caf25024c7721647ed2a23819bd509c.tar.gz |
Improve algorithm to reorder blocks in a function (#4911)
* Improve algorithm to reorder blocks in a function
In dead branch elimination, blocks can end up in a the wrong order, so
there is code to reorder the blocks in structured order. The problem is
that the algorithm to do that is very poor. It involves many searchs in
the function for the correct position to place the block, as well as
moving many block in the vector.
The solution is to write a specialized function in the function class
that will reorder the blocks in structured order. After computing the
structured order, reordering the block can be done in linear time, with
very little overhead.
-rw-r--r-- | source/opt/dead_branch_elim_pass.cpp | 13 | ||||
-rw-r--r-- | source/opt/function.cpp | 8 | ||||
-rw-r--r-- | source/opt/function.h | 41 | ||||
-rw-r--r-- | test/opt/function_test.cpp | 59 |
4 files changed, 110 insertions, 11 deletions
diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp index cc616ca6..d99b7f78 100644 --- a/source/opt/dead_branch_elim_pass.cpp +++ b/source/opt/dead_branch_elim_pass.cpp @@ -459,17 +459,8 @@ void DeadBranchElimPass::FixBlockOrder() { }; // Reorders blocks according to structured order. - ProcessFunction reorder_structured = [this](Function* function) { - std::list<BasicBlock*> order; - context()->cfg()->ComputeStructuredOrder(function, &*function->begin(), - &order); - std::vector<BasicBlock*> blocks; - for (auto block : order) { - blocks.push_back(block); - } - for (uint32_t i = 1; i < blocks.size(); ++i) { - function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]); - } + ProcessFunction reorder_structured = [](Function* function) { + function->ReorderBasicBlocksInStructuredOrder(); return true; }; diff --git a/source/opt/function.cpp b/source/opt/function.cpp index 38c66951..bb51df3f 100644 --- a/source/opt/function.cpp +++ b/source/opt/function.cpp @@ -270,5 +270,13 @@ std::string Function::PrettyPrint(uint32_t options) const { }); return str.str(); } + +void Function::ReorderBasicBlocksInStructuredOrder() { + std::list<BasicBlock*> order; + IRContext* context = this->def_inst_->context(); + context->cfg()->ComputeStructuredOrder(this, blocks_[0].get(), &order); + ReorderBasicBlocks(order.begin(), order.end()); +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/function.h b/source/opt/function.h index 917bf584..146cbe34 100644 --- a/source/opt/function.h +++ b/source/opt/function.h @@ -19,6 +19,7 @@ #include <functional> #include <memory> #include <string> +#include <unordered_set> #include <utility> #include <vector> @@ -180,7 +181,19 @@ class Function { // Returns true is a function declaration and not a function definition. bool IsDeclaration() { return begin() == end(); } + // Reorders the basic blocks in the function to match the structured order. + void ReorderBasicBlocksInStructuredOrder(); + private: + // Reorders the basic blocks in the function to match the order given by the + // range |{begin,end}|. The range must contain every basic block in the + // function, and no extras. + template <class It> + void ReorderBasicBlocks(It begin, It end); + + template <class It> + bool ContainsAllBlocksInTheFunction(It begin, It end); + // The OpFunction instruction that begins the definition of this function. std::unique_ptr<Instruction> def_inst_; // All parameters to this function. @@ -262,6 +275,34 @@ inline void Function::AddNonSemanticInstruction( non_semantic_.emplace_back(std::move(non_semantic)); } +template <class It> +void Function::ReorderBasicBlocks(It begin, It end) { + // Asserts to make sure every node in the function is in new_order. + assert(ContainsAllBlocksInTheFunction(begin, end)); + + // We have a pointer to all the elements in order, so we can release all + // pointers in |block_|, and then create the new unique pointers from |{begin, + // end}|. + std::for_each(blocks_.begin(), blocks_.end(), + [](std::unique_ptr<BasicBlock>& bb) { bb.release(); }); + std::transform(begin, end, blocks_.begin(), [](BasicBlock* bb) { + return std::unique_ptr<BasicBlock>(bb); + }); +} + +template <class It> +bool Function::ContainsAllBlocksInTheFunction(It begin, It end) { + std::unordered_multiset<BasicBlock*> range(begin, end); + if (range.size() != blocks_.size()) { + return false; + } + + for (auto& bb : blocks_) { + if (range.count(bb.get()) == 0) return false; + } + return true; +} + } // namespace opt } // namespace spvtools diff --git a/test/opt/function_test.cpp b/test/opt/function_test.cpp index af25bacc..34a03871 100644 --- a/test/opt/function_test.cpp +++ b/test/opt/function_test.cpp @@ -296,6 +296,65 @@ OpFunctionEnd EXPECT_EQ(1, non_semantic_ids.count(8)); } +TEST(FunctionTest, ReorderBlocksinStructuredOrder) { + // The spir-v has the basic block in a random order. We want to reorder them + // in structured order. + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %100 "PSMain" + OpExecutionMode %PSMain OriginUpperLeft + OpSource HLSL 600 + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %19 = OpTypeFunction %void + %bool = OpTypeBool +%undef_bool = OpUndef %bool +%undef_int = OpUndef %int + %100 = OpFunction %void None %19 + %11 = OpLabel + OpSelectionMerge %10 None + OpSwitch %undef_int %3 0 %2 10 %1 + %2 = OpLabel + OpReturn + %7 = OpLabel + OpBranch %8 + %3 = OpLabel + OpBranch %4 + %10 = OpLabel + OpReturn + %9 = OpLabel + OpBranch %10 + %8 = OpLabel + OpBranch %4 + %4 = OpLabel + OpLoopMerge %9 %8 None + OpBranchConditional %undef_bool %5 %9 + %1 = OpLabel + OpReturn + %6 = OpLabel + OpBranch %7 + %5 = OpLabel + OpSelectionMerge %7 None + OpBranchConditional %undef_bool %6 %7 + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> ctx = + spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_TRUE(ctx); + auto* func = spvtest::GetFunction(ctx->module(), 100); + ASSERT_TRUE(func); + func->ReorderBasicBlocksInStructuredOrder(); + + auto first_block = func->begin(); + auto bb = first_block; + for (++bb; bb != func->end(); ++bb) { + EXPECT_EQ(bb->id(), (bb - first_block)); + } +} + } // namespace } // namespace opt } // namespace spvtools |