aboutsummaryrefslogtreecommitdiff
path: root/source/fuzz/transformation_replace_opselect_with_conditional_branch.cpp
blob: b2010a94be977479259d754327e478f82e5e1b18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Copyright (c) 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/fuzz/transformation_replace_opselect_with_conditional_branch.h"

#include "source/fuzz/fuzzer_util.h"

namespace spvtools {
namespace fuzz {
TransformationReplaceOpSelectWithConditionalBranch::
    TransformationReplaceOpSelectWithConditionalBranch(
        protobufs::TransformationReplaceOpSelectWithConditionalBranch message)
    : message_(std::move(message)) {}

// Builds the transformation from its three ids:
// - |select_id|: result id of the instruction to be replaced.
// - |true_block_id|, |false_block_id|: ids for the blocks that may be created
//   on the true and false paths. A value of zero means that no block is
//   created for that path (see Apply).
TransformationReplaceOpSelectWithConditionalBranch::
    TransformationReplaceOpSelectWithConditionalBranch(
        uint32_t select_id, uint32_t true_block_id, uint32_t false_block_id) {
  // The instruction being rewritten.
  message_.set_select_id(select_id);

  // The (optional) ids of the blocks to be created.
  message_.set_true_block_id(true_block_id);
  message_.set_false_block_id(false_block_id);
}

// Returns true if and only if the transformation can be applied to the module
// held by |ir_context|. All of the following must hold:
// - At least one of |message_.true_block_id| and |message_.false_block_id| is
//   non-zero, and every non-zero one is fresh and distinct from the other.
// - |message_.select_id| is the result id of an OpSelect instruction whose
//   condition has a scalar boolean type.
// - The OpSelect is the first instruction of its block.
// - That block is not a merge block and has exactly one predecessor.
// - The predecessor is not the header of a construct (it has no merge
//   instruction) and it terminates with OpBranch.
// The transformation context is not consulted.
bool TransformationReplaceOpSelectWithConditionalBranch::IsApplicable(
    opt::IRContext* ir_context,
    const TransformationContext& /* unused */) const {
  assert((message_.true_block_id() || message_.false_block_id()) &&
         "At least one of the ids must be non-zero.");

  // The assertion above disappears when NDEBUG is defined. Reject the message
  // explicitly so that Apply can never be asked to create a conditional branch
  // whose two targets are both the instruction block.
  if (!message_.true_block_id() && !message_.false_block_id()) {
    return false;
  }

  // Check that the non-zero ids are fresh. |used_ids| also makes the check
  // fail if the same id is supplied for both blocks.
  std::set<uint32_t> used_ids;
  for (uint32_t id : {message_.true_block_id(), message_.false_block_id()}) {
    if (id && !CheckIdIsFreshAndNotUsedByThisTransformation(id, ir_context,
                                                            &used_ids)) {
      return false;
    }
  }

  auto instruction =
      ir_context->get_def_use_mgr()->GetDef(message_.select_id());

  // The instruction must exist and it must be an OpSelect instruction.
  if (!instruction || instruction->opcode() != spv::Op::OpSelect) {
    return false;
  }

  // Check that the condition is a scalar boolean. Input operand 0 of OpSelect
  // is the condition.
  auto condition = ir_context->get_def_use_mgr()->GetDef(
      instruction->GetSingleWordInOperand(0));
  assert(condition && "The condition should always exist in a valid module.");

  auto condition_type =
      ir_context->get_type_mgr()->GetType(condition->type_id());
  if (!condition_type->AsBool()) {
    return false;
  }

  auto block = ir_context->get_instr_block(instruction);
  assert(block && "The block containing the instruction must be found");

  // The instruction must be the first in its block.
  if (instruction->unique_id() != block->begin()->unique_id()) {
    return false;
  }

  // The block must not be a merge block.
  if (ir_context->GetStructuredCFGAnalysis()->IsMergeBlock(block->id())) {
    return false;
  }

  // The block must have exactly one predecessor. Bind by reference: only the
  // size and the first element are needed, so copying the list is wasteful.
  const auto& predecessors = ir_context->cfg()->preds(block->id());
  if (predecessors.size() != 1) {
    return false;
  }

  uint32_t pred_id = predecessors[0];
  auto predecessor = ir_context->get_instr_block(pred_id);

  // The predecessor must not be the header of a construct and it must end with
  // OpBranch.
  if (predecessor->GetMergeInst() != nullptr ||
      predecessor->terminator()->opcode() != spv::Op::OpBranch) {
    return false;
  }

  return true;
}

// Rewrites the OpSelect identified by |message_.select_id| as a selection
// construct:
// - For each non-zero id in {|message_.true_block_id|,
//   |message_.false_block_id|}, a new block with that id is inserted just
//   before the block containing the OpSelect; the new block branches
//   unconditionally to that block.
// - The OpBranch ending the predecessor of the OpSelect's block is replaced by
//   OpSelectionMerge + OpBranchConditional on the OpSelect's condition.
// - The OpSelect itself is turned into an OpPhi.
// All analyses are invalidated afterwards. The transformation context is not
// used.
void TransformationReplaceOpSelectWithConditionalBranch::Apply(
    opt::IRContext* ir_context, TransformationContext* /* unused */) const {
  opt::Instruction* select_inst =
      ir_context->get_def_use_mgr()->GetDef(message_.select_id());

  // The block holding the OpSelect becomes the merge block of the new
  // construct; its first listed predecessor becomes the header.
  opt::BasicBlock* merge_block = ir_context->get_instr_block(select_inst);
  opt::BasicBlock* header_block = ir_context->get_instr_block(
      ir_context->cfg()->preds(merge_block->id())[0]);

  const uint32_t merge_id = merge_block->id();
  const uint32_t true_block_id = message_.true_block_id();
  const uint32_t false_block_id = message_.false_block_id();

  // Creates a block labelled |fresh_id| whose only other instruction is an
  // unconditional branch to the merge block, and places it immediately before
  // the merge block in the function. Does nothing when |fresh_id| is zero.
  const auto insert_forwarding_block = [&](uint32_t fresh_id) {
    if (fresh_id == 0) {
      return;
    }

    fuzzerutil::UpdateModuleIdBound(ir_context, fresh_id);

    auto forwarding_block = MakeUnique<opt::BasicBlock>(
        MakeUnique<opt::Instruction>(ir_context, spv::Op::OpLabel, 0, fresh_id,
                                     opt::Instruction::OperandList{}));

    forwarding_block->AddInstruction(MakeUnique<opt::Instruction>(
        ir_context, spv::Op::OpBranch, 0, 0,
        opt::Instruction::OperandList{{SPV_OPERAND_TYPE_ID, {merge_id}}}));

    merge_block->GetParent()->InsertBasicBlockBefore(
        std::move(forwarding_block), merge_block);
  };

  // The true block (if any) is inserted first, so it ends up before the false
  // block in the function's block order.
  insert_forwarding_block(true_block_id);
  insert_forwarding_block(false_block_id);

  // Drop the header's OpBranch; it is replaced by the two instructions below.
  ir_context->KillInst(header_block->terminator());

  // Declare the selection construct, with the OpSelect's block as its merge
  // block.
  header_block->AddInstruction(MakeUnique<opt::Instruction>(
      ir_context, spv::Op::OpSelectionMerge, 0, 0,
      opt::Instruction::OperandList{
          {SPV_OPERAND_TYPE_ID, {merge_id}},
          {SPV_OPERAND_TYPE_SELECTION_CONTROL,
           {uint32_t(spv::SelectionControlMask::MaskNone)}}}));

  // Each side of the conditional branch goes to its newly-created block when
  // one exists, and straight to the merge block otherwise.
  uint32_t true_target = merge_id;
  if (true_block_id) {
    true_target = true_block_id;
  }

  uint32_t false_target = merge_id;
  if (false_block_id) {
    false_target = false_block_id;
  }

  assert(true_target != false_target &&
         "|if_block| and |else_block| should always be different, if the "
         "transformation is applicable.");

  // Terminate the header with a conditional branch on the OpSelect's
  // condition (its input operand 0): |true_target| is taken when the
  // condition holds and |false_target| otherwise.
  header_block->AddInstruction(MakeUnique<opt::Instruction>(
      ir_context, spv::Op::OpBranchConditional, 0, 0,
      opt::Instruction::OperandList{
          {SPV_OPERAND_TYPE_ID, {select_inst->GetSingleWordInOperand(0)}},
          {SPV_OPERAND_TYPE_ID, {true_target}},
          {SPV_OPERAND_TYPE_ID, {false_target}}}));

  // From the merge block's point of view, each value arrives either from the
  // corresponding new block or, when that block was not created, directly
  // from the header.
  const uint32_t header_id = header_block->id();
  const uint32_t true_pred = true_block_id ? true_block_id : header_id;
  const uint32_t false_pred = false_block_id ? false_block_id : header_id;

  // Turn the OpSelect into an OpPhi.
  // This:          OpSelect %type %cond %if %else
  // will become:   OpPhi %type %if %if_pred %else %else_pred
  select_inst->SetOpcode(spv::Op::OpPhi);

  opt::Instruction::OperandList phi_operands{
      select_inst->GetInOperand(1),
      opt::Operand{SPV_OPERAND_TYPE_ID, {true_pred}},
      select_inst->GetInOperand(2),
      opt::Operand{SPV_OPERAND_TYPE_ID, {false_pred}}};
  select_inst->SetInOperands(std::move(phi_operands));

  // The control flow of the module changed, so no analysis can be kept.
  ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
}

// Returns a generic protobufs::Transformation whose
// |replace_opselect_with_conditional_branch| field holds a copy of
// |message_|.
protobufs::Transformation
TransformationReplaceOpSelectWithConditionalBranch::ToMessage() const {
  protobufs::Transformation wrapped;
  auto* payload = wrapped.mutable_replace_opselect_with_conditional_branch();
  *payload = message_;
  return wrapped;
}

// Returns the set made of |message_.true_block_id| and
// |message_.false_block_id|. The ids are inserted as stored, without any
// filtering, so an id left at zero is part of the returned set.
std::unordered_set<uint32_t>
TransformationReplaceOpSelectWithConditionalBranch::GetFreshIds() const {
  std::unordered_set<uint32_t> fresh_ids;
  fresh_ids.insert(message_.true_block_id());
  fresh_ids.insert(message_.false_block_id());
  return fresh_ids;
}

}  // namespace fuzz
}  // namespace spvtools