aboutsummaryrefslogtreecommitdiff
path: root/source/validate_datarules.cpp
blob: 3d50328cf866169b9062834693371975c408d676 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
// Copyright (c) 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Ensures Data Rules are followed according to the specifications.

#include "validate.h"

#include <cassert>
#include <sstream>
#include <string>

#include "diagnostic.h"
#include "opcode.h"
#include "operand.h"
#include "val/instruction.h"
#include "val/validation_state.h"

using libspirv::CapabilitySet;
using libspirv::DiagnosticStream;
using libspirv::ValidationState_t;

namespace {

// Checks the component count of a vector type instruction.
// Counts of 2, 3 and 4 are always accepted. Counts of 8 and 16 are accepted
// only when the validation state reports the Vector16 capability. Every other
// count is reported as invalid data.
spv_result_t ValidateVecNumComponents(ValidationState_t& _,
                                      const spv_parsed_instruction_t* inst) {
  // The component count lives in operand 2.
  const uint32_t count = inst->words[inst->operands[2].offset];
  switch (count) {
    case 2:
    case 3:
    case 4:
      return SPV_SUCCESS;
    case 8:
    case 16:
      // Wide vectors are gated behind the Vector16 capability.
      if (!_.HasCapability(SpvCapabilityVector16)) {
        return _.diag(SPV_ERROR_INVALID_DATA)
               << "Having " << count << " components for "
               << spvOpcodeString(static_cast<SpvOp>(inst->opcode))
               << " requires the Vector16 capability";
      }
      return SPV_SUCCESS;
    default:
      break;
  }
  return _.diag(SPV_ERROR_INVALID_DATA)
         << "Illegal number of components (" << count << ") for "
         << spvOpcodeString(static_cast<SpvOp>(inst->opcode));
}

// Checks the bit width of an OpTypeFloat.
// A 32-bit width is always accepted. A 16-bit width is accepted when either
// the Float16 or the Float16Buffer capability is present, and a 64-bit width
// is accepted when the Float64 capability is present. Every other width is
// reported as invalid data.
spv_result_t ValidateFloatSize(ValidationState_t& _,
                               const spv_parsed_instruction_t* inst) {
  // The width in bits lives in operand 1.
  const uint32_t width = inst->words[inst->operands[1].offset];
  switch (width) {
    case 32:
      return SPV_SUCCESS;
    case 16: {
      const bool allowed = _.HasCapability(SpvCapabilityFloat16) ||
                           _.HasCapability(SpvCapabilityFloat16Buffer);
      if (allowed) return SPV_SUCCESS;
      return _.diag(SPV_ERROR_INVALID_DATA)
             << "Using a 16-bit floating point "
             << "type requires the Float16 or Float16Buffer capability.";
    }
    case 64: {
      if (!_.HasCapability(SpvCapabilityFloat64)) {
        return _.diag(SPV_ERROR_INVALID_DATA)
               << "Using a 64-bit floating point "
               << "type requires the Float64 capability.";
      }
      return SPV_SUCCESS;
    }
    default:
      break;
  }
  return _.diag(SPV_ERROR_INVALID_DATA)
         << "Invalid number of bits (" << width << ") used for OpTypeFloat.";
}

// Checks the bit width of an OpTypeInt.
// A 32-bit width is always accepted. Widths of 8, 16 and 64 bits are accepted
// only when the Int8, Int16 and Int64 capabilities, respectively, are present.
// Every other width is reported as invalid data.
spv_result_t ValidateIntSize(ValidationState_t& _,
                             const spv_parsed_instruction_t* inst) {
  // The width in bits lives in operand 1.
  const uint32_t width = inst->words[inst->operands[1].offset];
  switch (width) {
    case 32:
      return SPV_SUCCESS;
    case 8:
      if (!_.HasCapability(SpvCapabilityInt8)) {
        return _.diag(SPV_ERROR_INVALID_DATA)
               << "Using an 8-bit integer type requires the Int8 capability.";
      }
      return SPV_SUCCESS;
    case 16:
      if (!_.HasCapability(SpvCapabilityInt16)) {
        return _.diag(SPV_ERROR_INVALID_DATA)
               << "Using a 16-bit integer type requires the Int16 capability.";
      }
      return SPV_SUCCESS;
    case 64:
      if (!_.HasCapability(SpvCapabilityInt64)) {
        return _.diag(SPV_ERROR_INVALID_DATA)
               << "Using a 64-bit integer type requires the Int64 capability.";
      }
      return SPV_SUCCESS;
    default:
      break;
  }
  return _.diag(SPV_ERROR_INVALID_DATA)
         << "Invalid number of bits (" << width << ") used for OpTypeInt.";
}

// Validates that the matrix is parameterized with floating-point types.
// The column type (operand 1) must resolve to an OpTypeVector, and that
// vector's component type must resolve to an OpTypeFloat.
// Returns SPV_ERROR_INVALID_ID when the column type is not a (known) vector,
// SPV_ERROR_INVALID_DATA when the component type is not a (known) float, and
// SPV_SUCCESS otherwise.
spv_result_t ValidateMatrixColumnType(ValidationState_t& _,
                                      const spv_parsed_instruction_t* inst) {
  // Find the component type of matrix columns (must be vector).
  // Operand 1 is the <id> of the type specified for matrix columns.
  const auto column_type_id = inst->words[inst->operands[1].offset];
  const auto column_type = _.FindDef(column_type_id);
  // FindDef yields nullptr for an <id> that has no definition yet (see the
  // nullptr handling in ValidateStruct), so guard before dereferencing.
  if (column_type == nullptr || column_type->opcode() != SpvOpTypeVector) {
    return _.diag(SPV_ERROR_INVALID_ID)
           << "Columns in a matrix must be of type vector.";
  }

  // Trace back once more to find out the type of components in the vector.
  // Operand 1 is the <id> of the type of data in the vector.
  const auto component_type_id =
      column_type->words()[column_type->operands()[1].offset];
  const auto component_type = _.FindDef(component_type_id);
  if (component_type == nullptr ||
      component_type->opcode() != SpvOpTypeFloat) {
    return _.diag(SPV_ERROR_INVALID_DATA) << "Matrix types can only be "
                                             "parameterized with "
                                             "floating-point types.";
  }
  return SPV_SUCCESS;
}

// Checks the column count of an OpTypeMatrix.
// Only 2, 3 or 4 columns are accepted; anything else is reported as invalid
// data.
spv_result_t ValidateMatrixNumCols(ValidationState_t& _,
                                   const spv_parsed_instruction_t* inst) {
  // The column count lives in operand 2.
  const uint32_t columns = inst->words[inst->operands[2].offset];
  const bool column_count_ok = columns >= 2 && columns <= 4;
  if (column_count_ok) {
    return SPV_SUCCESS;
  }
  return _.diag(SPV_ERROR_INVALID_DATA) << "Matrix types can only be "
                                           "parameterized as having only 2, "
                                           "3, or 4 columns.";
}

// Validates that OpSpecConstant specializes to either int or float type.
// Returns SPV_ERROR_INVALID_DATA when the type <id> has no known definition or
// is defined by anything other than OpTypeInt / OpTypeFloat; SPV_SUCCESS
// otherwise.
spv_result_t ValidateSpecConstNumerical(ValidationState_t& _,
                                        const spv_parsed_instruction_t* inst) {
  // Operand 0 is the <id> of the type that we're specializing to.
  const auto type_id = inst->words[inst->operands[0].offset];
  const auto type_instruction = _.FindDef(type_id);
  // FindDef yields nullptr for an <id> that has no definition yet (see the
  // nullptr handling in ValidateStruct); an unknown type cannot be numerical,
  // so report it instead of dereferencing a null pointer.
  const bool is_numerical =
      type_instruction != nullptr &&
      (type_instruction->opcode() == SpvOpTypeInt ||
       type_instruction->opcode() == SpvOpTypeFloat);
  if (!is_numerical) {
    return _.diag(SPV_ERROR_INVALID_DATA) << "Specialization constant must be "
                                             "an integer or floating-point "
                                             "number.";
  }
  return SPV_SUCCESS;
}

// Validates that OpSpecConstantTrue and OpSpecConstantFalse specialize to bool.
// Returns SPV_ERROR_INVALID_ID when the result type <id> has no known
// definition or is defined by anything other than OpTypeBool; SPV_SUCCESS
// otherwise.
spv_result_t ValidateSpecConstBoolean(ValidationState_t& _,
                                      const spv_parsed_instruction_t* inst) {
  // Find out the type that we're specializing to.
  const auto type_instruction = _.FindDef(inst->type_id);
  // FindDef yields nullptr for an <id> that has no definition yet (see the
  // nullptr handling in ValidateStruct); guard before dereferencing.
  if (type_instruction == nullptr ||
      type_instruction->opcode() != SpvOpTypeBool) {
    return _.diag(SPV_ERROR_INVALID_ID) << "Specialization constant must be "
                                           "a boolean type.";
  }
  return SPV_SUCCESS;
}

// Registers the <id> named by an OpTypeForwardPointer with the validation
// state and returns the result of that registration.
// ValidateStruct later consults this registration: an undefined struct member
// type is accepted only if it was declared as a forward pointer.
spv_result_t ValidateForwardPointer(ValidationState_t& _,
                                    const spv_parsed_instruction_t* inst) {
  // Operand 0 carries the <id> being forward-declared.
  const uint32_t pointer_id = inst->words[inst->operands[0].offset];
  return _.RegisterForwardPointer(pointer_id);
}

// Checks every member type <id> of an OpTypeStruct.
// A member whose type has a known definition is fine. A member whose type is
// not yet defined is fine only if that <id> was registered as a forward
// pointer. The first member failing both tests yields SPV_ERROR_INVALID_ID.
spv_result_t ValidateStruct(ValidationState_t& _,
                            const spv_parsed_instruction_t* inst) {
  // Member types start at operand 1 and run to the last operand.
  for (unsigned index = 1; index < inst->num_operands; ++index) {
    const auto member_type_id = inst->words[inst->operands[index].offset];
    // Already defined: nothing more to check for this member.
    if (_.FindDef(member_type_id) != nullptr) continue;
    // Undefined, but legitimately forward-declared.
    if (_.IsForwardPointer(member_type_id)) continue;
    return _.diag(SPV_ERROR_INVALID_ID)
           << "Forward reference operands in an OpTypeStruct must first be "
              "declared using OpTypeForwardPointer.";
  }
  return SPV_SUCCESS;
}

}  // anonymous namespace

namespace libspirv {

// Entry point for the Data Rules checks
// (Data Rules subsection of 2.16.1 Universal Validation Rules).
// Dispatches on the instruction's opcode to the matching validator(s) and
// returns the first non-success result; opcodes without a data rule handled
// here return SPV_SUCCESS.
spv_result_t DataRulesPass(ValidationState_t& _,
                           const spv_parsed_instruction_t* inst) {
  switch (inst->opcode) {
    case SpvOpTypeVector:
      return ValidateVecNumComponents(_, inst);
    case SpvOpTypeFloat:
      return ValidateFloatSize(_, inst);
    case SpvOpTypeInt:
      return ValidateIntSize(_, inst);
    case SpvOpTypeMatrix: {
      // The column type is checked first; the column count is only examined
      // when the column type is acceptable.
      const spv_result_t column_type_status =
          ValidateMatrixColumnType(_, inst);
      if (column_type_status != SPV_SUCCESS) return column_type_status;
      return ValidateMatrixNumCols(_, inst);
    }
    // TODO(ehsan): Add OpSpecConstantComposite validation code.
    // TODO(ehsan): Add OpSpecConstantOp validation code (if any).
    case SpvOpSpecConstant:
      return ValidateSpecConstNumerical(_, inst);
    case SpvOpSpecConstantFalse:
    case SpvOpSpecConstantTrue:
      return ValidateSpecConstBoolean(_, inst);
    case SpvOpTypeForwardPointer:
      return ValidateForwardPointer(_, inst);
    case SpvOpTypeStruct:
      return ValidateStruct(_, inst);
    // TODO(ehsan): add more data rules validation here.
    default:
      break;
  }

  return SPV_SUCCESS;
}

}  // namespace libspirv