@@ -70,17 +70,22 @@ struct TosaLevel {
70
70
int32_t MAX_KERNEL = 0 ;
71
71
int32_t MAX_STRIDE = 0 ;
72
72
int32_t MAX_SCALE = 0 ;
73
-
74
- // @todo: MAX_LOG2_SIZE value and checks
73
+ int32_t MAX_LOG2_SIZE = 0 ;
74
+ int32_t MAX_NESTING = 0 ;
75
+ int32_t MAX_TENSOR_LIST_SIZE = 0 ;
75
76
76
77
bool operator ==(const TosaLevel &rhs) {
77
78
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
78
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
79
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
80
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
81
+ MAX_NESTING == rhs.MAX_NESTING &&
82
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE ;
79
83
}
80
84
};
81
85
82
- static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
83
- static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
86
+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 , 31 , 6 , 64 };
87
+ static constexpr TosaLevel TOSA_LEVEL_NONE = {32 , 2147483647 , 2147483647 , 2048 ,
88
+ 63 , 256 , 256 };
84
89
85
90
// ===----------------------------------------------------------------------===//
86
91
// TOSA Validation Pass.
@@ -147,107 +152,151 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
147
152
return true ;
148
153
}
149
154
150
- bool levelCheckRank (Operation *op, const Value &v,
151
- const std::string &checkDesc) {
155
+ bool levelCheckListSize (Operation *op, int32_t v,
156
+ const std::string &checkDesc) {
157
+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE ) {
158
+ op->emitOpError () << " failed level check for MAX_TENSOR_LIST_SIZE: "
159
+ << checkDesc;
160
+ return false ;
161
+ }
162
+ return true ;
163
+ }
164
+
165
+ bool levelCheckRankAndSizes (Operation *op, const Value &v,
166
+ const std::string &operandOrResult) {
152
167
if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
153
168
if (!type.hasRank ()) {
154
169
op->emitOpError () << " failed level check: unranked tensor" ;
155
170
return false ;
156
171
}
157
172
if (type.getRank () > tosaLevel.MAX_RANK ) {
158
- op->emitOpError () << " failed level check: " << checkDesc;
173
+ op->emitOpError () << " failed level check: " << operandOrResult
174
+ << " rank(shape) <= MAX_RANK" ;
159
175
return false ;
160
176
}
177
+
178
+ const int64_t max_dim = (INT64_C (1 ) << tosaLevel.MAX_LOG2_SIZE ) - 1 ;
179
+ const int64_t max_size =
180
+ (INT64_C (1 ) << (tosaLevel.MAX_LOG2_SIZE + 1 )) - 1 ;
181
+
182
+ auto shape = type.getShape ();
183
+ bool has_dynamic = false ;
184
+ for (auto dim : shape) {
185
+ if (mlir::ShapedType::isDynamic (dim)) {
186
+ has_dynamic = true ;
187
+ continue ;
188
+ }
189
+ if (dim > max_dim) {
190
+ op->emitOpError () << " failed level check: " << operandOrResult
191
+ << " shape dimension <= (1<<MAX_LOG2_SIZE) - 1" ;
192
+ return false ;
193
+ }
194
+ }
195
+ if (!has_dynamic) {
196
+ int64_t element_bits = type.getElementTypeBitWidth ();
197
+ int64_t element_bytes = std::max (INT64_C (1 ), element_bits / 8 );
198
+ int64_t size = element_bytes * type.getNumElements ();
199
+ if (size > max_size) {
200
+ op->emitOpError ()
201
+ << " failed level check: " << operandOrResult << " tensor size "
202
+ << size << " (in bytes) <= "
203
+ << " (1<<MAX_LOG2_SIZE+1) - 1, where max_size = " << max_size;
204
+ return false ;
205
+ }
206
+ }
161
207
}
162
208
return true ;
163
209
}
164
210
165
211
template <typename T>
166
- bool levelCheckRanksFor (Operation *op) {
212
+ bool levelCheckRanksAndSizesFor (Operation *op) {
167
213
if (dyn_cast<T>(op)) {
168
214
// level check ranks of all operands and results
169
215
for (auto v : op->getOperands ()) {
170
- if (!levelCheckRank (op, v, " operand rank(shape) <= MAX_RANK " ))
216
+ if (!levelCheckRankAndSizes (op, v, " operand" ))
171
217
return false ;
172
218
}
173
219
for (auto v : op->getResults ()) {
174
- if (!levelCheckRank (op, v, " result rank(shape) <= MAX_RANK " ))
220
+ if (!levelCheckRankAndSizes (op, v, " result" ))
175
221
return false ;
176
222
}
177
223
}
178
224
return true ;
179
225
}
180
226
181
- bool levelCheckRanks (Operation *op) {
182
- #define CHECK_RANKS_FOR (tosaOp ) \
183
- if (!levelCheckRanksFor <tosaOp##Op>(op)) \
227
+ bool levelCheckRanksAndSizes (Operation *op) {
228
+ #define CHECK_RANKS_AND_SIZES_FOR (tosaOp ) \
229
+ if (!levelCheckRanksAndSizesFor <tosaOp##Op>(op)) \
184
230
return false ;
185
231
186
232
// tensor operators:
187
- CHECK_RANKS_FOR (ArgMax);
233
+ CHECK_RANKS_AND_SIZES_FOR (ArgMax);
188
234
// all activation functions:
189
- CHECK_RANKS_FOR (Clamp);
190
- CHECK_RANKS_FOR (Sigmoid);
191
- CHECK_RANKS_FOR (Tanh);
235
+ CHECK_RANKS_AND_SIZES_FOR (Clamp);
236
+ CHECK_RANKS_AND_SIZES_FOR (Erf);
237
+ CHECK_RANKS_AND_SIZES_FOR (Sigmoid);
238
+ CHECK_RANKS_AND_SIZES_FOR (Tanh);
192
239
// all elementwise binary operators:
193
- CHECK_RANKS_FOR (Add);
194
- CHECK_RANKS_FOR (ArithmeticRightShift);
195
- CHECK_RANKS_FOR (BitwiseAnd);
196
- CHECK_RANKS_FOR (BitwiseOr);
197
- CHECK_RANKS_FOR (BitwiseXor);
198
- CHECK_RANKS_FOR (IntDiv);
199
- CHECK_RANKS_FOR (LogicalAnd);
200
- CHECK_RANKS_FOR (LogicalLeftShift);
201
- CHECK_RANKS_FOR (LogicalRightShift);
202
- CHECK_RANKS_FOR (LogicalOr);
203
- CHECK_RANKS_FOR (LogicalXor);
204
- CHECK_RANKS_FOR (Maximum);
205
- CHECK_RANKS_FOR (Minimum);
206
- CHECK_RANKS_FOR (Mul);
207
- CHECK_RANKS_FOR (Pow);
208
- CHECK_RANKS_FOR (Sub);
209
- CHECK_RANKS_FOR (Table);
240
+ CHECK_RANKS_AND_SIZES_FOR (Add);
241
+ CHECK_RANKS_AND_SIZES_FOR (ArithmeticRightShift);
242
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseAnd);
243
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseOr);
244
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseXor);
245
+ CHECK_RANKS_AND_SIZES_FOR (IntDiv);
246
+ CHECK_RANKS_AND_SIZES_FOR (LogicalAnd);
247
+ CHECK_RANKS_AND_SIZES_FOR (LogicalLeftShift);
248
+ CHECK_RANKS_AND_SIZES_FOR (LogicalRightShift);
249
+ CHECK_RANKS_AND_SIZES_FOR (LogicalOr);
250
+ CHECK_RANKS_AND_SIZES_FOR (LogicalXor);
251
+ CHECK_RANKS_AND_SIZES_FOR (Maximum);
252
+ CHECK_RANKS_AND_SIZES_FOR (Minimum);
253
+ CHECK_RANKS_AND_SIZES_FOR (Mul);
254
+ CHECK_RANKS_AND_SIZES_FOR (Pow);
255
+ CHECK_RANKS_AND_SIZES_FOR (Sub);
256
+ CHECK_RANKS_AND_SIZES_FOR (Table);
210
257
// all elementwise unary operators:
211
- CHECK_RANKS_FOR (Abs);
212
- CHECK_RANKS_FOR (BitwiseNot);
213
- CHECK_RANKS_FOR (Ceil);
214
- CHECK_RANKS_FOR (Clz);
215
- CHECK_RANKS_FOR (Exp);
216
- CHECK_RANKS_FOR (Floor);
217
- CHECK_RANKS_FOR (Log);
218
- CHECK_RANKS_FOR (LogicalNot);
219
- CHECK_RANKS_FOR (Negate);
220
- CHECK_RANKS_FOR (Reciprocal);
221
- CHECK_RANKS_FOR (Rsqrt);
258
+ CHECK_RANKS_AND_SIZES_FOR (Abs);
259
+ CHECK_RANKS_AND_SIZES_FOR (BitwiseNot);
260
+ CHECK_RANKS_AND_SIZES_FOR (Ceil);
261
+ CHECK_RANKS_AND_SIZES_FOR (Clz);
262
+ CHECK_RANKS_AND_SIZES_FOR (Cos);
263
+ CHECK_RANKS_AND_SIZES_FOR (Exp);
264
+ CHECK_RANKS_AND_SIZES_FOR (Floor);
265
+ CHECK_RANKS_AND_SIZES_FOR (Log);
266
+ CHECK_RANKS_AND_SIZES_FOR (LogicalNot);
267
+ CHECK_RANKS_AND_SIZES_FOR (Negate);
268
+ CHECK_RANKS_AND_SIZES_FOR (Reciprocal);
269
+ CHECK_RANKS_AND_SIZES_FOR (Rsqrt);
270
+ CHECK_RANKS_AND_SIZES_FOR (Sin);
222
271
// all elementwise ternary operators:
223
- CHECK_RANKS_FOR (Select);
272
+ CHECK_RANKS_AND_SIZES_FOR (Select);
224
273
// all comparison operators:
225
- CHECK_RANKS_FOR (Equal);
226
- CHECK_RANKS_FOR (Greater);
227
- CHECK_RANKS_FOR (GreaterEqual);
274
+ CHECK_RANKS_AND_SIZES_FOR (Equal);
275
+ CHECK_RANKS_AND_SIZES_FOR (Greater);
276
+ CHECK_RANKS_AND_SIZES_FOR (GreaterEqual);
228
277
// all reduction operators:
229
- CHECK_RANKS_FOR (ReduceAll);
230
- CHECK_RANKS_FOR (ReduceAny);
231
- CHECK_RANKS_FOR (ReduceMax);
232
- CHECK_RANKS_FOR (ReduceMin);
233
- CHECK_RANKS_FOR (ReduceProd);
234
- CHECK_RANKS_FOR (ReduceSum);
278
+ CHECK_RANKS_AND_SIZES_FOR (ReduceAll);
279
+ CHECK_RANKS_AND_SIZES_FOR (ReduceAny);
280
+ CHECK_RANKS_AND_SIZES_FOR (ReduceMax);
281
+ CHECK_RANKS_AND_SIZES_FOR (ReduceMin);
282
+ CHECK_RANKS_AND_SIZES_FOR (ReduceProd);
283
+ CHECK_RANKS_AND_SIZES_FOR (ReduceSum);
235
284
// all data layout operators:
236
- CHECK_RANKS_FOR (Concat);
237
- CHECK_RANKS_FOR (Pad);
238
- CHECK_RANKS_FOR (Reshape);
239
- CHECK_RANKS_FOR (Reverse);
240
- CHECK_RANKS_FOR (Slice);
241
- CHECK_RANKS_FOR (Tile);
242
- CHECK_RANKS_FOR (Transpose);
285
+ CHECK_RANKS_AND_SIZES_FOR (Concat);
286
+ CHECK_RANKS_AND_SIZES_FOR (Pad);
287
+ CHECK_RANKS_AND_SIZES_FOR (Reshape);
288
+ CHECK_RANKS_AND_SIZES_FOR (Reverse);
289
+ CHECK_RANKS_AND_SIZES_FOR (Slice);
290
+ CHECK_RANKS_AND_SIZES_FOR (Tile);
291
+ CHECK_RANKS_AND_SIZES_FOR (Transpose);
243
292
// all type conversion operators:
244
- CHECK_RANKS_FOR (Cast);
245
- CHECK_RANKS_FOR (Rescale);
293
+ CHECK_RANKS_AND_SIZES_FOR (Cast);
294
+ CHECK_RANKS_AND_SIZES_FOR (Rescale);
246
295
// all data nodes operators:
247
- CHECK_RANKS_FOR (Const);
248
- CHECK_RANKS_FOR (Identity);
296
+ CHECK_RANKS_AND_SIZES_FOR (Const);
297
+ CHECK_RANKS_AND_SIZES_FOR (Identity);
249
298
250
- #undef CHECK_RANKS_FOR
299
+ #undef CHECK_RANKS_AND_SIZES_FOR
251
300
return true ;
252
301
}
253
302
@@ -396,6 +445,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
396
445
return true ;
397
446
}
398
447
448
+ bool levelCheckListSize (Operation *op) {
449
+ if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
450
+ return levelCheckListSize (op, concat.getInput1 ().size (), " input1" );
451
+ }
452
+ if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
453
+ if (!levelCheckListSize (op, custom.getInputList ().size (), " input_list" ) ||
454
+ !levelCheckListSize (op, custom.getOutputList ().size (),
455
+ " output_list" )) {
456
+ return false ;
457
+ }
458
+ }
459
+ if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
460
+ if (!levelCheckListSize (op, condIf.getInputs ().size (), " inputs" ) ||
461
+ !levelCheckListSize (op, condIf.getOutput ().size (), " outputs" )) {
462
+ return false ;
463
+ }
464
+ }
465
+ if (auto w = dyn_cast<tosa::WhileOp>(op)) {
466
+ if (!levelCheckListSize (op, w.getInputs ().size (), " inputs" ) ||
467
+ !levelCheckListSize (op, w.getOutput ().size (), " outputs" )) {
468
+ return false ;
469
+ }
470
+ }
471
+ return true ;
472
+ }
473
+
399
474
// configure profile and level values from pass options profileName and
400
475
// levelName
401
476
void configLevelAndProfile () {
@@ -449,7 +524,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
449
524
return success ();
450
525
}
451
526
452
- if (!levelCheckRanks (op)) {
527
+ if (!levelCheckRanksAndSizes (op)) {
453
528
return failure ();
454
529
}
455
530
@@ -465,6 +540,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
465
540
return failure ();
466
541
}
467
542
543
+ // level check MAX_TENSOR_LIST_SIZE
544
+ if (!levelCheckListSize (op)) {
545
+ return failure ();
546
+ }
547
+
468
548
return success ();
469
549
}
470
550
@@ -695,6 +775,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
695
775
}
696
776
697
777
bool TosaValidation::isValidElementType (Type type) {
778
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
779
+ type = quantType.getStorageType ();
780
+
698
781
if (isa<FloatType>(type)) {
699
782
return type.isF32 () || type.isF16 () || type.isBF16 ();
700
783
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
0 commit comments