aboutsummaryrefslogtreecommitdiff
path: root/source/opt/desc_sroa.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/opt/desc_sroa.cpp')
-rw-r--r--source/opt/desc_sroa.cpp224
1 files changed, 91 insertions, 133 deletions
diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp
index 5e950069..bcbdde94 100644
--- a/source/opt/desc_sroa.cpp
+++ b/source/opt/desc_sroa.cpp
@@ -14,10 +14,19 @@
#include "source/opt/desc_sroa.h"
+#include "source/opt/desc_sroa_util.h"
#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
+namespace {
+
+bool IsDecorationBinding(Instruction* inst) {
+ if (inst->opcode() != SpvOpDecorate) return false;
+ return inst->GetSingleWordInOperand(1u) == SpvDecorationBinding;
+}
+
+} // namespace
Pass::Status DescriptorScalarReplacement::Process() {
bool modified = false;
@@ -25,7 +34,7 @@ Pass::Status DescriptorScalarReplacement::Process() {
std::vector<Instruction*> vars_to_kill;
for (Instruction& var : context()->types_values()) {
- if (IsCandidate(&var)) {
+ if (descsroautil::IsDescriptorArray(context(), &var)) {
modified = true;
if (!ReplaceCandidate(&var)) {
return Status::Failure;
@@ -41,72 +50,6 @@ Pass::Status DescriptorScalarReplacement::Process() {
return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
-bool DescriptorScalarReplacement::IsCandidate(Instruction* var) {
- if (var->opcode() != SpvOpVariable) {
- return false;
- }
-
- uint32_t ptr_type_id = var->type_id();
- Instruction* ptr_type_inst =
- context()->get_def_use_mgr()->GetDef(ptr_type_id);
- if (ptr_type_inst->opcode() != SpvOpTypePointer) {
- return false;
- }
-
- uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
- Instruction* var_type_inst =
- context()->get_def_use_mgr()->GetDef(var_type_id);
- if (var_type_inst->opcode() != SpvOpTypeArray &&
- var_type_inst->opcode() != SpvOpTypeStruct) {
- return false;
- }
-
- // All structures with descriptor assignments must be replaced by variables,
- // one for each of their members - with the exceptions of buffers.
- if (IsTypeOfStructuredBuffer(var_type_inst)) {
- return false;
- }
-
- bool has_desc_set_decoration = false;
- context()->get_decoration_mgr()->ForEachDecoration(
- var->result_id(), SpvDecorationDescriptorSet,
- [&has_desc_set_decoration](const Instruction&) {
- has_desc_set_decoration = true;
- });
- if (!has_desc_set_decoration) {
- return false;
- }
-
- bool has_binding_decoration = false;
- context()->get_decoration_mgr()->ForEachDecoration(
- var->result_id(), SpvDecorationBinding,
- [&has_binding_decoration](const Instruction&) {
- has_binding_decoration = true;
- });
- if (!has_binding_decoration) {
- return false;
- }
-
- return true;
-}
-
-bool DescriptorScalarReplacement::IsTypeOfStructuredBuffer(
- const Instruction* type) const {
- if (type->opcode() != SpvOpTypeStruct) {
- return false;
- }
-
- // All buffers have offset decorations for members of their structure types.
- // This is how we distinguish it from a structure of descriptors.
- bool has_offset_decoration = false;
- context()->get_decoration_mgr()->ForEachDecoration(
- type->result_id(), SpvDecorationOffset,
- [&has_offset_decoration](const Instruction&) {
- has_offset_decoration = true;
- });
- return has_offset_decoration;
-}
-
bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
std::vector<Instruction*> access_chain_work_list;
std::vector<Instruction*> load_work_list;
@@ -162,16 +105,15 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
return false;
}
- uint32_t idx_id = use->GetSingleWordInOperand(1);
- const analysis::Constant* idx_const =
- context()->get_constant_mgr()->FindDeclaredConstant(idx_id);
- if (idx_const == nullptr) {
+ const analysis::Constant* const_index =
+ descsroautil::GetAccessChainIndexAsConst(context(), use);
+ if (const_index == nullptr) {
context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
use);
return false;
}
- uint32_t idx = idx_const->GetU32();
+ uint32_t idx = const_index->GetU32();
uint32_t replacement_var = GetReplacementVariable(var, idx);
if (use->NumInOperands() == 2) {
@@ -208,39 +150,12 @@ uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
uint32_t idx) {
auto replacement_vars = replacement_variables_.find(var);
if (replacement_vars == replacement_variables_.end()) {
- uint32_t ptr_type_id = var->type_id();
- Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
- assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
- "Variable should be a pointer to an array or structure.");
- uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
- Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id);
- const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray;
- const bool is_struct = pointee_type_inst->opcode() == SpvOpTypeStruct;
- assert((is_array || is_struct) &&
- "Variable should be a pointer to an array or structure.");
-
- // For arrays, each array element should be replaced with a new replacement
- // variable
- if (is_array) {
- uint32_t array_len_id = pointee_type_inst->GetSingleWordInOperand(1);
- const analysis::Constant* array_len_const =
- context()->get_constant_mgr()->FindDeclaredConstant(array_len_id);
- assert(array_len_const != nullptr && "Array length must be a constant.");
- uint32_t array_len = array_len_const->GetU32();
-
- replacement_vars = replacement_variables_
- .insert({var, std::vector<uint32_t>(array_len, 0)})
- .first;
- }
- // For structures, each member should be replaced with a new replacement
- // variable
- if (is_struct) {
- const uint32_t num_members = pointee_type_inst->NumInOperands();
- replacement_vars =
- replacement_variables_
- .insert({var, std::vector<uint32_t>(num_members, 0)})
- .first;
- }
+ uint32_t number_of_elements =
+ descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var);
+ replacement_vars =
+ replacement_variables_
+ .insert({var, std::vector<uint32_t>(number_of_elements, 0)})
+ .first;
}
if (replacement_vars->second[idx] == 0) {
@@ -250,6 +165,74 @@ uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
return replacement_vars->second[idx];
}
+void DescriptorScalarReplacement::CopyDecorationsForNewVariable(
+ Instruction* old_var, uint32_t index, uint32_t new_var_id,
+ uint32_t new_var_ptr_type_id, const bool is_old_var_array,
+ const bool is_old_var_struct, Instruction* old_var_type) {
+ // Handle OpDecorate instructions.
+ for (auto old_decoration :
+ get_decoration_mgr()->GetDecorationsFor(old_var->result_id(), true)) {
+ uint32_t new_binding = 0;
+ if (IsDecorationBinding(old_decoration)) {
+ new_binding = GetNewBindingForElement(
+ old_decoration->GetSingleWordInOperand(2), index, new_var_ptr_type_id,
+ is_old_var_array, is_old_var_struct, old_var_type);
+ }
+ CreateNewDecorationForNewVariable(old_decoration, new_var_id, new_binding);
+ }
+
+ // Handle OpMemberDecorate instructions.
+ for (auto old_decoration : get_decoration_mgr()->GetDecorationsFor(
+ old_var_type->result_id(), true)) {
+ assert(old_decoration->opcode() == SpvOpMemberDecorate);
+ if (old_decoration->GetSingleWordInOperand(1u) != index) continue;
+ CreateNewDecorationForMemberDecorate(old_decoration, new_var_id);
+ }
+}
+
+uint32_t DescriptorScalarReplacement::GetNewBindingForElement(
+ uint32_t old_binding, uint32_t index, uint32_t new_var_ptr_type_id,
+ const bool is_old_var_array, const bool is_old_var_struct,
+ Instruction* old_var_type) {
+ if (is_old_var_array) {
+ return old_binding + index * GetNumBindingsUsedByType(new_var_ptr_type_id);
+ }
+ if (is_old_var_struct) {
+ // The binding offset that should be added is the sum of binding
+ // numbers used by previous members of the current struct.
+ uint32_t new_binding = old_binding;
+ for (uint32_t i = 0; i < index; ++i) {
+ new_binding +=
+ GetNumBindingsUsedByType(old_var_type->GetSingleWordInOperand(i));
+ }
+ return new_binding;
+ }
+ return old_binding;
+}
+
+void DescriptorScalarReplacement::CreateNewDecorationForNewVariable(
+ Instruction* old_decoration, uint32_t new_var_id, uint32_t new_binding) {
+ assert(old_decoration->opcode() == SpvOpDecorate);
+ std::unique_ptr<Instruction> new_decoration(old_decoration->Clone(context()));
+ new_decoration->SetInOperand(0, {new_var_id});
+
+ if (IsDecorationBinding(new_decoration.get())) {
+ new_decoration->SetInOperand(2, {new_binding});
+ }
+ context()->AddAnnotationInst(std::move(new_decoration));
+}
+
+void DescriptorScalarReplacement::CreateNewDecorationForMemberDecorate(
+ Instruction* old_member_decoration, uint32_t new_var_id) {
+ std::vector<Operand> operands(
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {new_var_id}}});
+ auto new_decorate_operand_begin = old_member_decoration->begin() + 2u;
+ auto new_decorate_operand_end = old_member_decoration->end();
+ operands.insert(operands.end(), new_decorate_operand_begin,
+ new_decorate_operand_end);
+ get_decoration_mgr()->AddDecoration(SpvOpDecorate, std::move(operands));
+}
+
uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
Instruction* var, uint32_t idx) {
// The storage class for the new variable is the same as the original.
@@ -285,33 +268,8 @@ uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
{static_cast<uint32_t>(storage_class)}}}));
context()->AddGlobalValue(std::move(variable));
- // Copy all of the decorations to the new variable. The only difference is
- // the Binding decoration needs to be adjusted.
- for (auto old_decoration :
- get_decoration_mgr()->GetDecorationsFor(var->result_id(), true)) {
- assert(old_decoration->opcode() == SpvOpDecorate);
- std::unique_ptr<Instruction> new_decoration(
- old_decoration->Clone(context()));
- new_decoration->SetInOperand(0, {id});
-
- uint32_t decoration = new_decoration->GetSingleWordInOperand(1u);
- if (decoration == SpvDecorationBinding) {
- uint32_t new_binding = new_decoration->GetSingleWordInOperand(2);
- if (is_array) {
- new_binding += idx * GetNumBindingsUsedByType(ptr_element_type_id);
- }
- if (is_struct) {
- // The binding offset that should be added is the sum of binding numbers
- // used by previous members of the current struct.
- for (uint32_t i = 0; i < idx; ++i) {
- new_binding += GetNumBindingsUsedByType(
- pointee_type_inst->GetSingleWordInOperand(i));
- }
- }
- new_decoration->SetInOperand(2, {new_binding});
- }
- context()->AddAnnotationInst(std::move(new_decoration));
- }
+ CopyDecorationsForNewVariable(var, idx, id, ptr_element_type_id, is_array,
+ is_struct, pointee_type_inst);
// Create a new OpName for the replacement variable.
std::vector<std::unique_ptr<Instruction>> names_to_add;
@@ -377,7 +335,7 @@ uint32_t DescriptorScalarReplacement::GetNumBindingsUsedByType(
// The number of bindings consumed by a structure is the sum of the bindings
// used by its members.
if (type_inst->opcode() == SpvOpTypeStruct &&
- !IsTypeOfStructuredBuffer(type_inst)) {
+ !descsroautil::IsTypeOfStructuredBuffer(context(), type_inst)) {
uint32_t sum = 0;
for (uint32_t i = 0; i < type_inst->NumInOperands(); i++)
sum += GetNumBindingsUsedByType(type_inst->GetSingleWordInOperand(i));