aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSteven Perron <stevenperron@google.com>2022-08-31 11:06:15 -0400
committerGitHub <noreply@github.com>2022-08-31 11:06:15 -0400
commitd51dc53d2caf25024c7721647ed2a23819bd509c (patch)
treefbb897dfed3d0b83ec6b7e0432fb09ecf783ba1a
parentfca39d5cb4420f391aacaa0d506c65544663754b (diff)
downloadSPIRV-Tools-d51dc53d2caf25024c7721647ed2a23819bd509c.tar.gz
Improve algorithm to reorder blocks in a function (#4911)
* Improve algorithm to reorder blocks in a function In dead branch elimination, blocks can end up in a the wrong order, so there is code to reorder the blocks in structured order. The problem is that the algorithm to do that is very poor. It involves many searchs in the function for the correct position to place the block, as well as moving many block in the vector. The solution is to write a specialized function in the function class that will reorder the blocks in structured order. After computing the structured order, reordering the block can be done in linear time, with very little overhead.
-rw-r--r--source/opt/dead_branch_elim_pass.cpp13
-rw-r--r--source/opt/function.cpp8
-rw-r--r--source/opt/function.h41
-rw-r--r--test/opt/function_test.cpp59
4 files changed, 110 insertions, 11 deletions
diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp
index cc616ca6..d99b7f78 100644
--- a/source/opt/dead_branch_elim_pass.cpp
+++ b/source/opt/dead_branch_elim_pass.cpp
@@ -459,17 +459,8 @@ void DeadBranchElimPass::FixBlockOrder() {
};
// Reorders blocks according to structured order.
- ProcessFunction reorder_structured = [this](Function* function) {
- std::list<BasicBlock*> order;
- context()->cfg()->ComputeStructuredOrder(function, &*function->begin(),
- &order);
- std::vector<BasicBlock*> blocks;
- for (auto block : order) {
- blocks.push_back(block);
- }
- for (uint32_t i = 1; i < blocks.size(); ++i) {
- function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]);
- }
+ ProcessFunction reorder_structured = [](Function* function) {
+ function->ReorderBasicBlocksInStructuredOrder();
return true;
};
diff --git a/source/opt/function.cpp b/source/opt/function.cpp
index 38c66951..bb51df3f 100644
--- a/source/opt/function.cpp
+++ b/source/opt/function.cpp
@@ -270,5 +270,13 @@ std::string Function::PrettyPrint(uint32_t options) const {
});
return str.str();
}
+
+void Function::ReorderBasicBlocksInStructuredOrder() {
+ std::list<BasicBlock*> order;
+ IRContext* context = this->def_inst_->context();
+ context->cfg()->ComputeStructuredOrder(this, blocks_[0].get(), &order);
+ ReorderBasicBlocks(order.begin(), order.end());
+}
+
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/function.h b/source/opt/function.h
index 917bf584..146cbe34 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -19,6 +19,7 @@
#include <functional>
#include <memory>
#include <string>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -180,7 +181,19 @@ class Function {
// Returns true is a function declaration and not a function definition.
bool IsDeclaration() { return begin() == end(); }
+ // Reorders the basic blocks in the function to match the structured order.
+ void ReorderBasicBlocksInStructuredOrder();
+
private:
+ // Reorders the basic blocks in the function to match the order given by the
+ // range |{begin,end}|. The range must contain every basic block in the
+ // function, and no extras.
+ template <class It>
+ void ReorderBasicBlocks(It begin, It end);
+
+ template <class It>
+ bool ContainsAllBlocksInTheFunction(It begin, It end);
+
// The OpFunction instruction that begins the definition of this function.
std::unique_ptr<Instruction> def_inst_;
// All parameters to this function.
@@ -262,6 +275,34 @@ inline void Function::AddNonSemanticInstruction(
non_semantic_.emplace_back(std::move(non_semantic));
}
+template <class It>
+void Function::ReorderBasicBlocks(It begin, It end) {
+ // Asserts to make sure every node in the function is in new_order.
+ assert(ContainsAllBlocksInTheFunction(begin, end));
+
+ // We have a pointer to all the elements in order, so we can release all
+ // pointers in |block_|, and then create the new unique pointers from |{begin,
+ // end}|.
+ std::for_each(blocks_.begin(), blocks_.end(),
+ [](std::unique_ptr<BasicBlock>& bb) { bb.release(); });
+ std::transform(begin, end, blocks_.begin(), [](BasicBlock* bb) {
+ return std::unique_ptr<BasicBlock>(bb);
+ });
+}
+
+template <class It>
+bool Function::ContainsAllBlocksInTheFunction(It begin, It end) {
+ std::unordered_multiset<BasicBlock*> range(begin, end);
+ if (range.size() != blocks_.size()) {
+ return false;
+ }
+
+ for (auto& bb : blocks_) {
+ if (range.count(bb.get()) == 0) return false;
+ }
+ return true;
+}
+
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/function_test.cpp b/test/opt/function_test.cpp
index af25bacc..34a03871 100644
--- a/test/opt/function_test.cpp
+++ b/test/opt/function_test.cpp
@@ -296,6 +296,65 @@ OpFunctionEnd
EXPECT_EQ(1, non_semantic_ids.count(8));
}
+TEST(FunctionTest, ReorderBlocksinStructuredOrder) {
+ // The spir-v has the basic block in a random order. We want to reorder them
+ // in structured order.
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %100 "PSMain"
+ OpExecutionMode %PSMain OriginUpperLeft
+ OpSource HLSL 600
+ %int = OpTypeInt 32 1
+ %void = OpTypeVoid
+ %19 = OpTypeFunction %void
+ %bool = OpTypeBool
+%undef_bool = OpUndef %bool
+%undef_int = OpUndef %int
+ %100 = OpFunction %void None %19
+ %11 = OpLabel
+ OpSelectionMerge %10 None
+ OpSwitch %undef_int %3 0 %2 10 %1
+ %2 = OpLabel
+ OpReturn
+ %7 = OpLabel
+ OpBranch %8
+ %3 = OpLabel
+ OpBranch %4
+ %10 = OpLabel
+ OpReturn
+ %9 = OpLabel
+ OpBranch %10
+ %8 = OpLabel
+ OpBranch %4
+ %4 = OpLabel
+ OpLoopMerge %9 %8 None
+ OpBranchConditional %undef_bool %5 %9
+ %1 = OpLabel
+ OpReturn
+ %6 = OpLabel
+ OpBranch %7
+ %5 = OpLabel
+ OpSelectionMerge %7 None
+ OpBranchConditional %undef_bool %6 %7
+ OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> ctx =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_TRUE(ctx);
+ auto* func = spvtest::GetFunction(ctx->module(), 100);
+ ASSERT_TRUE(func);
+ func->ReorderBasicBlocksInStructuredOrder();
+
+ auto first_block = func->begin();
+ auto bb = first_block;
+ for (++bb; bb != func->end(); ++bb) {
+ EXPECT_EQ(bb->id(), (bb - first_block));
+ }
+}
+
} // namespace
} // namespace opt
} // namespace spvtools