@@ -61,17 +61,22 @@ struct TosaLevel {
61
61
int32_t MAX_KERNEL = 0 ;
62
62
int32_t MAX_STRIDE = 0 ;
63
63
int32_t MAX_SCALE = 0 ;
64
-
65
- // @todo: MAX_LOG2_SIZE value and checks
64
+ int32_t MAX_LOG2_SIZE = 0 ;
65
+ int32_t MAX_NESTING = 0 ;
66
+ int32_t MAX_TENSOR_LIST_SIZE = 0 ;
66
67
67
68
bool operator ==(const TosaLevel &rhs) {
68
69
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
69
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
70
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
71
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
72
+ MAX_NESTING == rhs.MAX_NESTING &&
73
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE ;
70
74
}
71
75
};
72
76
73
- static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
74
- static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
77
+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 , 31 , 6 , 64 };
78
+ static constexpr TosaLevel TOSA_LEVEL_NONE = {32 , 2147483647 , 2147483647 , 2048 ,
79
+ 63 , 256 , 256 };
75
80
76
81
// ===----------------------------------------------------------------------===//
77
82
// TOSA Validation Pass.
@@ -137,107 +142,188 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
137
142
return true ;
138
143
}
139
144
140
- bool levelCheckRank (Operation *op, const Value &v,
141
- const std::string &checkDesc) {
145
+ bool levelCheckListSize (Operation *op, int32_t v,
146
+ const std::string &checkDesc) {
147
+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE ) {
148
+ op->emitOpError () << " failed level check for MAX_TENSOR_LIST_SIZE: "
149
+ << checkDesc;
150
+ return false ;
151
+ }
152
+ return true ;
153
+ }
154
+
155
+ bool levelCheckRankAndSizes (Operation *op, const Value &v,
156
+ const std::string &operandOrResult,
157
+ int32_t highest_rank) {
142
158
if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
143
159
if (!type.hasRank ()) {
144
160
op->emitOpError () << " failed level check: unranked tensor" ;
145
161
return false ;
146
162
}
147
- if (type.getRank () > tosaLevel.MAX_RANK ) {
148
- op->emitOpError () << " failed level check: " << checkDesc;
163
+ if (type.getRank () > highest_rank) {
164
+ op->emitOpError () << " failed level check: " << operandOrResult
165
+ << " rank(shape) <= MAX_RANK" ;
166
+ return false ;
167
+ }
168
+
169
+ auto shape = type.getShape ();
170
+ for (auto dim : shape) {
171
+ if (mlir::ShapedType::isDynamic (dim)) {
172
+ op->emitOpError () << " failed level check: " << operandOrResult
173
+ << " shape dimension cannot be dynamic" ;
174
+ return false ;
175
+ }
176
+ }
177
+
178
+ int64_t element_bits = type.getElementTypeBitWidth ();
179
+ int64_t element_bytes = std::max (INT64_C (1 ), element_bits / 8 );
180
+ int64_t size = element_bytes * type.getNumElements ();
181
+
182
+ // According to 1.11. Tensor Definitions of Tosa spec, the value of
183
+ // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
184
+ // defined in 1.7. Levels.
185
+ // For each tensor, the number of tensor elements multiplied by the
186
+ // element size in bytes must be representable as a tensor_size_t.
187
+ const int64_t max_size = (INT64_C (1 ) << tosaLevel.MAX_LOG2_SIZE ) - 1 ;
188
+ if (size > max_size) {
189
+ op->emitOpError ()
190
+ << " failed level check: " << operandOrResult
191
+ << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)" ;
149
192
return false ;
150
193
}
151
194
}
152
195
return true ;
153
196
}
154
197
155
198
template <typename T>
156
- bool levelCheckRanksFor (Operation *op ) {
157
- if (dyn_cast<T>(op)) {
158
- // level check ranks of all operands and results
159
- for (auto v : op->getOperands ()) {
160
- if (!levelCheckRank (op, v, " operand rank(shape) <= MAX_RANK" ))
161
- return false ;
162
- }
163
- for ( auto v : op-> getResults ()) {
164
- if (! levelCheckRank (op, v, " result rank(shape) <= MAX_RANK " ))
165
- return false ;
166
- }
199
+ bool levelCheckRanksAndSizesFor (T tosaOp ) {
200
+ // level check ranks of all operands and results
201
+ auto op = tosaOp. getOperation ();
202
+ for (auto v : op->getOperands ()) {
203
+ if (!levelCheckRankAndSizes (op, v, " operand" , tosaLevel. MAX_RANK ))
204
+ return false ;
205
+ }
206
+
207
+ for ( auto v : op-> getResults ()) {
208
+ if (! levelCheckRankAndSizes (op, v, " result " , tosaLevel. MAX_RANK ))
209
+ return false ;
167
210
}
168
211
return true ;
169
212
}
170
213
171
- bool levelCheckRanks (Operation *op) {
172
- #define CHECK_RANKS_FOR (tosaOp ) \
173
- if (!levelCheckRanksFor<tosaOp##Op>(op)) \
174
- return false ;
214
+ template <>
215
+ bool levelCheckRanksAndSizesFor (tosa::ArgMaxOp tosaOp) {
216
+ auto op = tosaOp.getOperation ();
217
+ if (!levelCheckRankAndSizes (op, tosaOp.getInput (), " operand" ,
218
+ tosaLevel.MAX_RANK ))
219
+ return false ;
220
+
221
+ // rank(output) = rank(input) - 1
222
+ if (!levelCheckRankAndSizes (op, tosaOp.getOutput (), " result" ,
223
+ tosaLevel.MAX_RANK - 1 ))
224
+ return false ;
225
+
226
+ return true ;
227
+ }
228
+
229
+ template <>
230
+ bool levelCheckRanksAndSizesFor (tosa::IfOp tosaOp) {
231
+ auto op = tosaOp.getOperation ();
232
+
233
+ // Only the condition input has rank limitation.
234
+ if (!levelCheckRankAndSizes (op, tosaOp.getCond (), " operand" ,
235
+ tosaLevel.MAX_RANK ))
236
+ return false ;
237
+
238
+ return true ;
239
+ }
240
+
241
+ bool levelCheckRanksAndSizes (Operation *op) {
242
+ #define CHECK_RANKS_AND_SIZES_FOR (tosaOp ) \
243
+ if (isa<tosa::tosaOp##Op>(op)) \
244
+ if (!levelCheckRanksAndSizesFor (cast<tosa::tosaOp##Op>(op))) \
245
+ return false ;
246
+
247
+ #define CHECK_RANKS_AND_SIZES_SKIP (tosaOp ) \
248
+ if (isa<tosa::tosaOp##Op>(op)) \
249
+ return true ;
175
250
176
251
// tensor operators:
177
- CHECK_RANKS_FOR (ArgMax);
252
+ CHECK_RANKS_AND_SIZES_FOR (ArgMax);
178
253
// all activation functions:
179
- CHECK_RANKS_FOR (Clamp);
180
- CHECK_RANKS_FOR (Sigmoid);
181
- CHECK_RANKS_FOR (Tanh);
254
+ CHECK_RANKS_AND_SIZES_FOR (Clamp);
255
+ CHECK_RANKS_AND_SIZES_FOR (Erf);
256
+ CHECK_RANKS_AND_SIZES_FOR (Sigmoid);
257
+ CHECK_RANKS_AND_SIZES_FOR (Tanh);
182
258
// all elementwise binary operators:
183
- CHECK_RANKS_FOR (Add);
184
- CHECK_RANKS_FOR (ArithmeticRightShift);
185
- CHECK_RANKS_FOR (BitwiseAnd);
186
- CHECK_RANKS_FOR (BitwiseOr);
187
- CHECK_RANKS_FOR (BitwiseXor);
188
- CHECK_RANKS_FOR (IntDiv);
189
- CHECK_RANKS_FOR (LogicalAnd);
190
- CHECK_RANKS_FOR (LogicalLeftShift);
191
- CHECK_RANKS_FOR (LogicalRightShift);
192
- CHECK_RANKS_FOR (LogicalOr);
193
- CHECK_RANKS_FOR (LogicalXor);
194
- CHECK_RANKS_FOR (Maximum);
195
- CHECK_RANKS_FOR (Minimum);
196
- CHECK_RANKS_FOR (Mul);
197
- CHECK_RANKS_FOR (Pow);
198
- CHECK_RANKS_FOR (Sub);
199
- CHECK_RANKS_FOR (Table);
259
+ CHECK_RANKS_AND_SIZES_FOR (Add);
260
+ CHECK_RANKS_AND_SIZES_FOR (ArithmeticRightShift);
261
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseAnd);
262
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseOr);
263
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseXor);
264
+ CHECK_RANKS_AND_SIZES_FOR (IntDiv);
265
+ CHECK_RANKS_AND_SIZES_FOR (LogicalAnd);
266
+ CHECK_RANKS_AND_SIZES_FOR (LogicalLeftShift);
267
+ CHECK_RANKS_AND_SIZES_FOR (LogicalRightShift);
268
+ CHECK_RANKS_AND_SIZES_FOR (LogicalOr);
269
+ CHECK_RANKS_AND_SIZES_FOR (LogicalXor);
270
+ CHECK_RANKS_AND_SIZES_FOR (Maximum);
271
+ CHECK_RANKS_AND_SIZES_FOR (Minimum);
272
+ CHECK_RANKS_AND_SIZES_FOR (Mul);
273
+ CHECK_RANKS_AND_SIZES_FOR (Pow);
274
+ CHECK_RANKS_AND_SIZES_FOR (Sub);
275
+ CHECK_RANKS_AND_SIZES_FOR (Table);
200
276
// all elementwise unary operators:
201
- CHECK_RANKS_FOR (Abs);
202
- CHECK_RANKS_FOR (BitwiseNot);
203
- CHECK_RANKS_FOR (Ceil);
204
- CHECK_RANKS_FOR (Clz);
205
- CHECK_RANKS_FOR (Exp);
206
- CHECK_RANKS_FOR (Floor);
207
- CHECK_RANKS_FOR (Log);
208
- CHECK_RANKS_FOR (LogicalNot);
209
- CHECK_RANKS_FOR (Negate);
210
- CHECK_RANKS_FOR (Reciprocal);
211
- CHECK_RANKS_FOR (Rsqrt);
277
+ CHECK_RANKS_AND_SIZES_FOR (Abs);
278
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseNot);
279
+ CHECK_RANKS_AND_SIZES_FOR (Ceil);
280
+ CHECK_RANKS_AND_SIZES_FOR (Clz);
281
+ CHECK_RANKS_AND_SIZES_FOR (Cos);
282
+ CHECK_RANKS_AND_SIZES_FOR (Exp);
283
+ CHECK_RANKS_AND_SIZES_FOR (Floor);
284
+ CHECK_RANKS_AND_SIZES_FOR (Log);
285
+ CHECK_RANKS_AND_SIZES_FOR (LogicalNot);
286
+ CHECK_RANKS_AND_SIZES_FOR (Negate);
287
+ CHECK_RANKS_AND_SIZES_FOR (Reciprocal);
288
+ CHECK_RANKS_AND_SIZES_FOR (Rsqrt);
289
+ CHECK_RANKS_AND_SIZES_FOR (Sin);
212
290
// all elementwise ternary operators:
213
- CHECK_RANKS_FOR (Select);
291
+ CHECK_RANKS_AND_SIZES_FOR (Select);
214
292
// all comparison operators:
215
- CHECK_RANKS_FOR (Equal);
216
- CHECK_RANKS_FOR (Greater);
217
- CHECK_RANKS_FOR (GreaterEqual);
293
+ CHECK_RANKS_AND_SIZES_FOR (Equal);
294
+ CHECK_RANKS_AND_SIZES_FOR (Greater);
295
+ CHECK_RANKS_AND_SIZES_FOR (GreaterEqual);
218
296
// all reduction operators:
219
- CHECK_RANKS_FOR (ReduceAll);
220
- CHECK_RANKS_FOR (ReduceAny);
221
- CHECK_RANKS_FOR (ReduceMax);
222
- CHECK_RANKS_FOR (ReduceMin);
223
- CHECK_RANKS_FOR (ReduceProduct);
224
- CHECK_RANKS_FOR (ReduceSum);
297
+ CHECK_RANKS_AND_SIZES_FOR (ReduceAll);
298
+ CHECK_RANKS_AND_SIZES_FOR (ReduceAny);
299
+ CHECK_RANKS_AND_SIZES_FOR (ReduceMax);
300
+ CHECK_RANKS_AND_SIZES_FOR (ReduceMin);
301
+ CHECK_RANKS_AND_SIZES_FOR (ReduceProduct);
302
+ CHECK_RANKS_AND_SIZES_FOR (ReduceSum);
225
303
// all data layout operators:
226
- CHECK_RANKS_FOR (Concat);
227
- CHECK_RANKS_FOR (Pad);
228
- CHECK_RANKS_FOR (Reshape);
229
- CHECK_RANKS_FOR (Reverse);
230
- CHECK_RANKS_FOR (Slice);
231
- CHECK_RANKS_FOR (Tile);
232
- CHECK_RANKS_FOR (Transpose);
304
+ CHECK_RANKS_AND_SIZES_FOR (Concat);
305
+ CHECK_RANKS_AND_SIZES_FOR (Pad);
306
+ CHECK_RANKS_AND_SIZES_FOR (Reshape);
307
+ CHECK_RANKS_AND_SIZES_FOR (Reverse);
308
+ CHECK_RANKS_AND_SIZES_FOR (Slice);
309
+ CHECK_RANKS_AND_SIZES_FOR (Tile);
310
+ CHECK_RANKS_AND_SIZES_FOR (Transpose);
233
311
// all type conversion operators:
234
- CHECK_RANKS_FOR (Cast);
235
- CHECK_RANKS_FOR (Rescale);
312
+ CHECK_RANKS_AND_SIZES_FOR (Cast);
313
+ CHECK_RANKS_AND_SIZES_FOR (Rescale);
314
+ // control flow operators:
315
+ CHECK_RANKS_AND_SIZES_FOR (If);
236
316
// all data nodes operators:
237
- CHECK_RANKS_FOR (Const);
238
- CHECK_RANKS_FOR (Identity);
317
+ CHECK_RANKS_AND_SIZES_FOR (Const);
318
+ CHECK_RANKS_AND_SIZES_FOR (Identity);
319
+
320
+ // The following operators do not have level rank and size constraint.
321
+ CHECK_RANKS_AND_SIZES_SKIP (Yield);
322
+ CHECK_RANKS_AND_SIZES_SKIP (Custom);
323
+ CHECK_RANKS_AND_SIZES_SKIP (While);
239
324
240
- #undef CHECK_RANKS_FOR
325
+ #undef CHECK_RANKS_AND_SIZES_FOR
326
+ #undef CHECK_RANKS_AND_SIZES_SKIP
241
327
return true ;
242
328
}
243
329
@@ -386,6 +472,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386
472
return true ;
387
473
}
388
474
475
+ bool levelCheckListSize (Operation *op) {
476
+ if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
477
+ return levelCheckListSize (op, concat.getInput1 ().size (), " input1" );
478
+ }
479
+ if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
480
+ if (!levelCheckListSize (op, custom.getInputList ().size (), " input_list" ) ||
481
+ !levelCheckListSize (op, custom.getOutputList ().size (),
482
+ " output_list" )) {
483
+ return false ;
484
+ }
485
+ }
486
+ if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
487
+ if (!levelCheckListSize (op, condIf.getInputs ().size (), " inputs" ) ||
488
+ !levelCheckListSize (op, condIf.getOutput ().size (), " outputs" )) {
489
+ return false ;
490
+ }
491
+ }
492
+ if (auto w = dyn_cast<tosa::WhileOp>(op)) {
493
+ if (!levelCheckListSize (op, w.getInputs ().size (), " inputs" ) ||
494
+ !levelCheckListSize (op, w.getOutput ().size (), " outputs" )) {
495
+ return false ;
496
+ }
497
+ }
498
+ return true ;
499
+ }
500
+
389
501
// configure profile and level values from pass options profileName and
390
502
// levelName
391
503
void configLevelAndProfile () {
@@ -439,7 +551,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439
551
return success ();
440
552
}
441
553
442
- if (!levelCheckRanks (op)) {
554
+ if (!levelCheckRanksAndSizes (op)) {
443
555
return failure ();
444
556
}
445
557
@@ -455,6 +567,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
455
567
return failure ();
456
568
}
457
569
570
+ // level check MAX_TENSOR_LIST_SIZE
571
+ if (!levelCheckListSize (op)) {
572
+ return failure ();
573
+ }
574
+
458
575
return success ();
459
576
}
460
577
@@ -685,6 +802,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
685
802
}
686
803
687
804
bool TosaValidation::isValidElementType (Type type) {
805
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
806
+ type = quantType.getStorageType ();
807
+
688
808
if (isa<FloatType>(type)) {
689
809
return type.isF32 () || type.isF16 () || type.isBF16 ();
690
810
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
0 commit comments