@@ -61,17 +61,22 @@ struct TosaLevel {
61
61
int32_t MAX_KERNEL = 0 ;
62
62
int32_t MAX_STRIDE = 0 ;
63
63
int32_t MAX_SCALE = 0 ;
64
-
65
- // @todo: MAX_LOG2_SIZE value and checks
64
+ int32_t MAX_LOG2_SIZE = 0 ;
65
+ int32_t MAX_NESTING = 0 ;
66
+ int32_t MAX_TENSOR_LIST_SIZE = 0 ;
66
67
67
68
bool operator ==(const TosaLevel &rhs) {
68
69
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
69
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
70
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
71
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
72
+ MAX_NESTING == rhs.MAX_NESTING &&
73
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE ;
70
74
}
71
75
};
72
76
73
- static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
74
- static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
77
+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 , 31 , 6 , 64 };
78
+ static constexpr TosaLevel TOSA_LEVEL_NONE = {32 , 2147483647 , 2147483647 , 2048 ,
79
+ 63 , 256 , 256 };
75
80
76
81
// ===----------------------------------------------------------------------===//
77
82
// TOSA Validation Pass.
@@ -111,133 +116,212 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
111
116
constCheckers.emplace_back (checkConstantOperandPad);
112
117
}
113
118
114
- bool levelCheckKernel (Operation *op, int32_t v,
115
- const std::string &checkDesc) {
119
+ bool levelCheckKernel (Operation *op, int32_t v, const StringRef checkDesc) {
116
120
if (v > tosaLevel.MAX_KERNEL ) {
117
121
op->emitOpError () << " failed level check: " << checkDesc;
118
122
return false ;
119
123
}
120
124
return true ;
121
125
}
122
126
123
- bool levelCheckStride (Operation *op, int32_t v,
124
- const std::string &checkDesc) {
127
+ bool levelCheckStride (Operation *op, int32_t v, const StringRef checkDesc) {
125
128
if (v > tosaLevel.MAX_STRIDE ) {
126
129
op->emitOpError () << " failed level check: " << checkDesc;
127
130
return false ;
128
131
}
129
132
return true ;
130
133
}
131
134
132
- bool levelCheckScale (Operation *op, int32_t v, const std::string & checkDesc) {
135
+ bool levelCheckScale (Operation *op, int32_t v, const StringRef checkDesc) {
133
136
if (v > tosaLevel.MAX_SCALE ) {
134
137
op->emitOpError () << " failed level check: " << checkDesc;
135
138
return false ;
136
139
}
137
140
return true ;
138
141
}
139
142
140
- bool levelCheckRank (Operation *op, const Value &v,
141
- const std::string &checkDesc) {
143
+ bool levelCheckListSize (Operation *op, int32_t v, const StringRef checkDesc) {
144
+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE ) {
145
+ op->emitOpError () << " failed level check for MAX_TENSOR_LIST_SIZE: "
146
+ << checkDesc;
147
+ return false ;
148
+ }
149
+ return true ;
150
+ }
151
+
152
+ bool levelCheckRankAndSizes (Operation *op, const Value &v,
153
+ const StringRef operandOrResult,
154
+ int32_t highest_rank) {
142
155
if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
143
156
if (!type.hasRank ()) {
144
157
op->emitOpError () << " failed level check: unranked tensor" ;
145
158
return false ;
146
159
}
147
- if (type.getRank () > tosaLevel.MAX_RANK ) {
148
- op->emitOpError () << " failed level check: " << checkDesc;
160
+ if (type.getRank () > highest_rank) {
161
+ op->emitOpError () << " failed level check: " << operandOrResult
162
+ << " rank(shape) <= MAX_RANK" ;
163
+ return false ;
164
+ }
165
+
166
+ auto shape = type.getShape ();
167
+ for (auto dim : shape) {
168
+ if (mlir::ShapedType::isDynamic (dim)) {
169
+ op->emitOpError () << " failed level check: " << operandOrResult
170
+ << " shape dimension cannot be dynamic" ;
171
+ return false ;
172
+ }
173
+ }
174
+
175
+ int64_t element_bits = type.getElementTypeBitWidth ();
176
+ int64_t element_bytes = std::max (INT64_C (1 ), element_bits / 8 );
177
+ int64_t size = element_bytes * type.getNumElements ();
178
+
179
+ // According to 1.11. Tensor Definitions of Tosa spec, the value of
180
+ // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
181
+ // defined in 1.7. Levels.
182
+ // For each tensor, the number of tensor elements multiplied by the
183
+ // element size in bytes must be representable as a tensor_size_t.
184
+ const int64_t max_size = (INT64_C (1 ) << tosaLevel.MAX_LOG2_SIZE ) - 1 ;
185
+ if (size > max_size) {
186
+ op->emitOpError ()
187
+ << " failed level check: " << operandOrResult
188
+ << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)" ;
149
189
return false ;
150
190
}
151
191
}
152
192
return true ;
153
193
}
154
194
155
195
template <typename T>
156
- bool levelCheckRanksFor (Operation *op ) {
157
- if (dyn_cast<T>(op)) {
158
- // level check ranks of all operands and results
159
- for (auto v : op->getOperands ()) {
160
- if (!levelCheckRank (op, v, " operand rank(shape) <= MAX_RANK" ))
161
- return false ;
162
- }
163
- for ( auto v : op-> getResults ()) {
164
- if (! levelCheckRank (op, v, " result rank(shape) <= MAX_RANK " ))
165
- return false ;
166
- }
196
+ bool levelCheckRanksAndSizesFor (T tosaOp ) {
197
+ // level check ranks of all operands and results
198
+ auto op = tosaOp. getOperation ();
199
+ for (auto v : op->getOperands ()) {
200
+ if (!levelCheckRankAndSizes (op, v, " operand" , tosaLevel. MAX_RANK ))
201
+ return false ;
202
+ }
203
+
204
+ for ( auto v : op-> getResults ()) {
205
+ if (! levelCheckRankAndSizes (op, v, " result " , tosaLevel. MAX_RANK ))
206
+ return false ;
167
207
}
168
208
return true ;
169
209
}
170
210
171
- bool levelCheckRanks (Operation *op) {
172
- #define CHECK_RANKS_FOR (tosaOp ) \
173
- if (!levelCheckRanksFor<tosaOp##Op>(op)) \
174
- return false ;
211
+ template <>
212
+ bool levelCheckRanksAndSizesFor (tosa::ArgMaxOp tosaOp) {
213
+ auto op = tosaOp.getOperation ();
214
+ if (!levelCheckRankAndSizes (op, tosaOp.getInput (), " operand" ,
215
+ tosaLevel.MAX_RANK ))
216
+ return false ;
217
+
218
+ // rank(output) = rank(input) - 1
219
+ if (!levelCheckRankAndSizes (op, tosaOp.getOutput (), " result" ,
220
+ tosaLevel.MAX_RANK - 1 ))
221
+ return false ;
222
+
223
+ return true ;
224
+ }
225
+
226
+ template <>
227
+ bool levelCheckRanksAndSizesFor (tosa::IfOp tosaOp) {
228
+ auto op = tosaOp.getOperation ();
229
+
230
+ // Only the condition input has rank limitation.
231
+ if (!levelCheckRankAndSizes (op, tosaOp.getCond (), " operand" ,
232
+ tosaLevel.MAX_RANK ))
233
+ return false ;
234
+
235
+ return true ;
236
+ }
237
+
238
+ bool levelCheckRanksAndSizes (Operation *op) {
239
+ #define CHECK_RANKS_AND_SIZES_FOR (tosaOp ) \
240
+ if (isa<tosa::tosaOp##Op>(op)) \
241
+ if (!levelCheckRanksAndSizesFor (cast<tosa::tosaOp##Op>(op))) \
242
+ return false ;
243
+
244
+ #define CHECK_RANKS_AND_SIZES_SKIP (tosaOp ) \
245
+ if (isa<tosa::tosaOp##Op>(op)) \
246
+ return true ;
175
247
176
248
// tensor operators:
177
- CHECK_RANKS_FOR (ArgMax);
249
+ CHECK_RANKS_AND_SIZES_FOR (ArgMax);
178
250
// all activation functions:
179
- CHECK_RANKS_FOR (Clamp);
180
- CHECK_RANKS_FOR (Sigmoid);
181
- CHECK_RANKS_FOR (Tanh);
251
+ CHECK_RANKS_AND_SIZES_FOR (Clamp);
252
+ CHECK_RANKS_AND_SIZES_FOR (Erf);
253
+ CHECK_RANKS_AND_SIZES_FOR (Sigmoid);
254
+ CHECK_RANKS_AND_SIZES_FOR (Tanh);
182
255
// all elementwise binary operators:
183
- CHECK_RANKS_FOR (Add);
184
- CHECK_RANKS_FOR (ArithmeticRightShift);
185
- CHECK_RANKS_FOR (BitwiseAnd);
186
- CHECK_RANKS_FOR (BitwiseOr);
187
- CHECK_RANKS_FOR (BitwiseXor);
188
- CHECK_RANKS_FOR (IntDiv);
189
- CHECK_RANKS_FOR (LogicalAnd);
190
- CHECK_RANKS_FOR (LogicalLeftShift);
191
- CHECK_RANKS_FOR (LogicalRightShift);
192
- CHECK_RANKS_FOR (LogicalOr);
193
- CHECK_RANKS_FOR (LogicalXor);
194
- CHECK_RANKS_FOR (Maximum);
195
- CHECK_RANKS_FOR (Minimum);
196
- CHECK_RANKS_FOR (Mul);
197
- CHECK_RANKS_FOR (Pow);
198
- CHECK_RANKS_FOR (Sub);
199
- CHECK_RANKS_FOR (Table);
256
+ CHECK_RANKS_AND_SIZES_FOR (Add);
257
+ CHECK_RANKS_AND_SIZES_FOR (ArithmeticRightShift);
258
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseAnd);
259
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseOr);
260
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseXor);
261
+ CHECK_RANKS_AND_SIZES_FOR (IntDiv);
262
+ CHECK_RANKS_AND_SIZES_FOR (LogicalAnd);
263
+ CHECK_RANKS_AND_SIZES_FOR (LogicalLeftShift);
264
+ CHECK_RANKS_AND_SIZES_FOR (LogicalRightShift);
265
+ CHECK_RANKS_AND_SIZES_FOR (LogicalOr);
266
+ CHECK_RANKS_AND_SIZES_FOR (LogicalXor);
267
+ CHECK_RANKS_AND_SIZES_FOR (Maximum);
268
+ CHECK_RANKS_AND_SIZES_FOR (Minimum);
269
+ CHECK_RANKS_AND_SIZES_FOR (Mul);
270
+ CHECK_RANKS_AND_SIZES_FOR (Pow);
271
+ CHECK_RANKS_AND_SIZES_FOR (Sub);
272
+ CHECK_RANKS_AND_SIZES_FOR (Table);
200
273
// all elementwise unary operators:
201
- CHECK_RANKS_FOR (Abs);
202
- CHECK_RANKS_FOR (BitwiseNot);
203
- CHECK_RANKS_FOR (Ceil);
204
- CHECK_RANKS_FOR (Clz);
205
- CHECK_RANKS_FOR (Exp);
206
- CHECK_RANKS_FOR (Floor);
207
- CHECK_RANKS_FOR (Log);
208
- CHECK_RANKS_FOR (LogicalNot);
209
- CHECK_RANKS_FOR (Negate);
210
- CHECK_RANKS_FOR (Reciprocal);
211
- CHECK_RANKS_FOR (Rsqrt);
274
+ CHECK_RANKS_AND_SIZES_FOR (Abs);
275
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseNot);
276
+ CHECK_RANKS_AND_SIZES_FOR (Ceil);
277
+ CHECK_RANKS_AND_SIZES_FOR (Clz);
278
+ CHECK_RANKS_AND_SIZES_FOR (Cos);
279
+ CHECK_RANKS_AND_SIZES_FOR (Exp);
280
+ CHECK_RANKS_AND_SIZES_FOR (Floor);
281
+ CHECK_RANKS_AND_SIZES_FOR (Log);
282
+ CHECK_RANKS_AND_SIZES_FOR (LogicalNot);
283
+ CHECK_RANKS_AND_SIZES_FOR (Negate);
284
+ CHECK_RANKS_AND_SIZES_FOR (Reciprocal);
285
+ CHECK_RANKS_AND_SIZES_FOR (Rsqrt);
286
+ CHECK_RANKS_AND_SIZES_FOR (Sin);
212
287
// all elementwise ternary operators:
213
- CHECK_RANKS_FOR (Select);
288
+ CHECK_RANKS_AND_SIZES_FOR (Select);
214
289
// all comparison operators:
215
- CHECK_RANKS_FOR (Equal);
216
- CHECK_RANKS_FOR (Greater);
217
- CHECK_RANKS_FOR (GreaterEqual);
290
+ CHECK_RANKS_AND_SIZES_FOR (Equal);
291
+ CHECK_RANKS_AND_SIZES_FOR (Greater);
292
+ CHECK_RANKS_AND_SIZES_FOR (GreaterEqual);
218
293
// all reduction operators:
219
- CHECK_RANKS_FOR (ReduceAll);
220
- CHECK_RANKS_FOR (ReduceAny);
221
- CHECK_RANKS_FOR (ReduceMax);
222
- CHECK_RANKS_FOR (ReduceMin);
223
- CHECK_RANKS_FOR (ReduceProduct);
224
- CHECK_RANKS_FOR (ReduceSum);
294
+ CHECK_RANKS_AND_SIZES_FOR (ReduceAll);
295
+ CHECK_RANKS_AND_SIZES_FOR (ReduceAny);
296
+ CHECK_RANKS_AND_SIZES_FOR (ReduceMax);
297
+ CHECK_RANKS_AND_SIZES_FOR (ReduceMin);
298
+ CHECK_RANKS_AND_SIZES_FOR (ReduceProduct);
299
+ CHECK_RANKS_AND_SIZES_FOR (ReduceSum);
225
300
// all data layout operators:
226
- CHECK_RANKS_FOR (Concat);
227
- CHECK_RANKS_FOR (Pad);
228
- CHECK_RANKS_FOR (Reshape);
229
- CHECK_RANKS_FOR (Reverse);
230
- CHECK_RANKS_FOR (Slice);
231
- CHECK_RANKS_FOR (Tile);
232
- CHECK_RANKS_FOR (Transpose);
301
+ CHECK_RANKS_AND_SIZES_FOR (Concat);
302
+ CHECK_RANKS_AND_SIZES_FOR (Pad);
303
+ CHECK_RANKS_AND_SIZES_FOR (Reshape);
304
+ CHECK_RANKS_AND_SIZES_FOR (Reverse);
305
+ CHECK_RANKS_AND_SIZES_FOR (Slice);
306
+ CHECK_RANKS_AND_SIZES_FOR (Tile);
307
+ CHECK_RANKS_AND_SIZES_FOR (Transpose);
233
308
// all type conversion operators:
234
- CHECK_RANKS_FOR (Cast);
235
- CHECK_RANKS_FOR (Rescale);
309
+ CHECK_RANKS_AND_SIZES_FOR (Cast);
310
+ CHECK_RANKS_AND_SIZES_FOR (Rescale);
311
+ // control flow operators:
312
+ CHECK_RANKS_AND_SIZES_FOR (If);
236
313
// all data nodes operators:
237
- CHECK_RANKS_FOR (Const);
238
- CHECK_RANKS_FOR (Identity);
314
+ CHECK_RANKS_AND_SIZES_FOR (Const);
315
+ CHECK_RANKS_AND_SIZES_FOR (Identity);
239
316
240
- #undef CHECK_RANKS_FOR
317
+ // The following operators do not have level rank and size constraint.
318
+ CHECK_RANKS_AND_SIZES_SKIP (Resize);
319
+ CHECK_RANKS_AND_SIZES_SKIP (Yield);
320
+ CHECK_RANKS_AND_SIZES_SKIP (Custom);
321
+ CHECK_RANKS_AND_SIZES_SKIP (While);
322
+
323
+ #undef CHECK_RANKS_AND_SIZES_FOR
324
+ #undef CHECK_RANKS_AND_SIZES_SKIP
241
325
return true ;
242
326
}
243
327
@@ -386,6 +470,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386
470
return true ;
387
471
}
388
472
473
+ bool levelCheckListSize (Operation *op) {
474
+ if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
475
+ return levelCheckListSize (op, concat.getInput1 ().size (), " input1" );
476
+ }
477
+ if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
478
+ if (!levelCheckListSize (op, custom.getInputList ().size (), " input_list" ) ||
479
+ !levelCheckListSize (op, custom.getOutputList ().size (),
480
+ " output_list" )) {
481
+ return false ;
482
+ }
483
+ }
484
+ if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
485
+ if (!levelCheckListSize (op, condIf.getInputs ().size (), " inputs" ) ||
486
+ !levelCheckListSize (op, condIf.getOutput ().size (), " outputs" )) {
487
+ return false ;
488
+ }
489
+ }
490
+ if (auto w = dyn_cast<tosa::WhileOp>(op)) {
491
+ if (!levelCheckListSize (op, w.getInputs ().size (), " inputs" ) ||
492
+ !levelCheckListSize (op, w.getOutput ().size (), " outputs" )) {
493
+ return false ;
494
+ }
495
+ }
496
+ return true ;
497
+ }
498
+
389
499
// configure profile and level values from pass options profileName and
390
500
// levelName
391
501
void configLevelAndProfile () {
@@ -439,7 +549,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439
549
return success ();
440
550
}
441
551
442
- if (!levelCheckRanks (op)) {
552
+ if (!levelCheckRanksAndSizes (op)) {
443
553
return failure ();
444
554
}
445
555
@@ -455,6 +565,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
455
565
return failure ();
456
566
}
457
567
568
+ // level check MAX_TENSOR_LIST_SIZE
569
+ if (!levelCheckListSize (op)) {
570
+ return failure ();
571
+ }
572
+
458
573
return success ();
459
574
}
460
575
0 commit comments