Skip to content

Commit fcd29d6

Browse files
Tai78641tatwaichong
andcommitted
[mlir][tosa] Add several level checks
Add the following types of level check to consolidate the level validity - Add rank level checks for Erf, Cos, and Sin. - Add MAX_LOG2_SIZE level check: The maximum value is 63 when the level is set to "none" and 31 when the level is set to "8K". - Add MAX_TENSOR_LIST_SIZE level check : The maximum value is 256 when the level is set to "none" and 64 when the level is set to "8K". Co-authored-by: TatWai Chong <[email protected]> Change-Id: I797fafe504219e43950824c04839c7187065fe8e
1 parent 4c9e14b commit fcd29d6

File tree

2 files changed

+742
-75
lines changed

2 files changed

+742
-75
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 152 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,22 @@ struct TosaLevel {
7070
int32_t MAX_KERNEL = 0;
7171
int32_t MAX_STRIDE = 0;
7272
int32_t MAX_SCALE = 0;
73-
74-
// @todo: MAX_LOG2_SIZE value and checks
73+
int32_t MAX_LOG2_SIZE = 0;
74+
int32_t MAX_NESTING = 0;
75+
int32_t MAX_TENSOR_LIST_SIZE = 0;
7576

7677
bool operator==(const TosaLevel &rhs) {
7778
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
78-
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
79+
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
80+
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
81+
MAX_NESTING == rhs.MAX_NESTING &&
82+
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
7983
}
8084
};
8185

82-
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
83-
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
86+
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
87+
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
88+
63, 256, 256};
8489

8590
//===----------------------------------------------------------------------===//
8691
// TOSA Validation Pass.
@@ -147,107 +152,151 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
147152
return true;
148153
}
149154

150-
bool levelCheckRank(Operation *op, const Value &v,
151-
const std::string &checkDesc) {
155+
bool levelCheckListSize(Operation *op, int32_t v,
156+
const std::string &checkDesc) {
157+
if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
158+
op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
159+
<< checkDesc;
160+
return false;
161+
}
162+
return true;
163+
}
164+
165+
bool levelCheckRankAndSizes(Operation *op, const Value &v,
166+
const std::string &operandOrResult) {
152167
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
153168
if (!type.hasRank()) {
154169
op->emitOpError() << "failed level check: unranked tensor";
155170
return false;
156171
}
157172
if (type.getRank() > tosaLevel.MAX_RANK) {
158-
op->emitOpError() << "failed level check: " << checkDesc;
173+
op->emitOpError() << "failed level check: " << operandOrResult
174+
<< " rank(shape) <= MAX_RANK";
159175
return false;
160176
}
177+
178+
const int64_t max_dim = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
179+
const int64_t max_size =
180+
(INT64_C(1) << (tosaLevel.MAX_LOG2_SIZE + 1)) - 1;
181+
182+
auto shape = type.getShape();
183+
bool has_dynamic = false;
184+
for (auto dim : shape) {
185+
if (mlir::ShapedType::isDynamic(dim)) {
186+
has_dynamic = true;
187+
continue;
188+
}
189+
if (dim > max_dim) {
190+
op->emitOpError() << "failed level check: " << operandOrResult
191+
<< " shape dimension <= (1<<MAX_LOG2_SIZE) - 1";
192+
return false;
193+
}
194+
}
195+
if (!has_dynamic) {
196+
int64_t element_bits = type.getElementTypeBitWidth();
197+
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
198+
int64_t size = element_bytes * type.getNumElements();
199+
if (size > max_size) {
200+
op->emitOpError()
201+
<< "failed level check: " << operandOrResult << " tensor size "
202+
<< size << " (in bytes) <= "
203+
<< "(1<<MAX_LOG2_SIZE+1) - 1, where max_size = " << max_size;
204+
return false;
205+
}
206+
}
161207
}
162208
return true;
163209
}
164210

165211
template <typename T>
166-
bool levelCheckRanksFor(Operation *op) {
212+
bool levelCheckRanksAndSizesFor(Operation *op) {
167213
if (dyn_cast<T>(op)) {
168214
// level check ranks of all operands and results
169215
for (auto v : op->getOperands()) {
170-
if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
216+
if (!levelCheckRankAndSizes(op, v, "operand"))
171217
return false;
172218
}
173219
for (auto v : op->getResults()) {
174-
if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
220+
if (!levelCheckRankAndSizes(op, v, "result"))
175221
return false;
176222
}
177223
}
178224
return true;
179225
}
180226

181-
bool levelCheckRanks(Operation *op) {
182-
#define CHECK_RANKS_FOR(tosaOp) \
183-
if (!levelCheckRanksFor<tosaOp##Op>(op)) \
227+
bool levelCheckRanksAndSizes(Operation *op) {
228+
#define CHECK_RANKS_AND_SIZES_FOR(tosaOp) \
229+
if (!levelCheckRanksAndSizesFor<tosaOp##Op>(op)) \
184230
return false;
185231

186232
// tensor operators:
187-
CHECK_RANKS_FOR(ArgMax);
233+
CHECK_RANKS_AND_SIZES_FOR(ArgMax);
188234
// all activation functions:
189-
CHECK_RANKS_FOR(Clamp);
190-
CHECK_RANKS_FOR(Sigmoid);
191-
CHECK_RANKS_FOR(Tanh);
235+
CHECK_RANKS_AND_SIZES_FOR(Clamp);
236+
CHECK_RANKS_AND_SIZES_FOR(Erf);
237+
CHECK_RANKS_AND_SIZES_FOR(Sigmoid);
238+
CHECK_RANKS_AND_SIZES_FOR(Tanh);
192239
// all elementwise binary operators:
193-
CHECK_RANKS_FOR(Add);
194-
CHECK_RANKS_FOR(ArithmeticRightShift);
195-
CHECK_RANKS_FOR(BitwiseAnd);
196-
CHECK_RANKS_FOR(BitwiseOr);
197-
CHECK_RANKS_FOR(BitwiseXor);
198-
CHECK_RANKS_FOR(IntDiv);
199-
CHECK_RANKS_FOR(LogicalAnd);
200-
CHECK_RANKS_FOR(LogicalLeftShift);
201-
CHECK_RANKS_FOR(LogicalRightShift);
202-
CHECK_RANKS_FOR(LogicalOr);
203-
CHECK_RANKS_FOR(LogicalXor);
204-
CHECK_RANKS_FOR(Maximum);
205-
CHECK_RANKS_FOR(Minimum);
206-
CHECK_RANKS_FOR(Mul);
207-
CHECK_RANKS_FOR(Pow);
208-
CHECK_RANKS_FOR(Sub);
209-
CHECK_RANKS_FOR(Table);
240+
CHECK_RANKS_AND_SIZES_FOR(Add);
241+
CHECK_RANKS_AND_SIZES_FOR(ArithmeticRightShift);
242+
CHECK_RANKS_AND_SIZES_FOR(BitwiseAnd);
243+
CHECK_RANKS_AND_SIZES_FOR(BitwiseOr);
244+
CHECK_RANKS_AND_SIZES_FOR(BitwiseXor);
245+
CHECK_RANKS_AND_SIZES_FOR(IntDiv);
246+
CHECK_RANKS_AND_SIZES_FOR(LogicalAnd);
247+
CHECK_RANKS_AND_SIZES_FOR(LogicalLeftShift);
248+
CHECK_RANKS_AND_SIZES_FOR(LogicalRightShift);
249+
CHECK_RANKS_AND_SIZES_FOR(LogicalOr);
250+
CHECK_RANKS_AND_SIZES_FOR(LogicalXor);
251+
CHECK_RANKS_AND_SIZES_FOR(Maximum);
252+
CHECK_RANKS_AND_SIZES_FOR(Minimum);
253+
CHECK_RANKS_AND_SIZES_FOR(Mul);
254+
CHECK_RANKS_AND_SIZES_FOR(Pow);
255+
CHECK_RANKS_AND_SIZES_FOR(Sub);
256+
CHECK_RANKS_AND_SIZES_FOR(Table);
210257
// all elementwise unary operators:
211-
CHECK_RANKS_FOR(Abs);
212-
CHECK_RANKS_FOR(BitwiseNot);
213-
CHECK_RANKS_FOR(Ceil);
214-
CHECK_RANKS_FOR(Clz);
215-
CHECK_RANKS_FOR(Exp);
216-
CHECK_RANKS_FOR(Floor);
217-
CHECK_RANKS_FOR(Log);
218-
CHECK_RANKS_FOR(LogicalNot);
219-
CHECK_RANKS_FOR(Negate);
220-
CHECK_RANKS_FOR(Reciprocal);
221-
CHECK_RANKS_FOR(Rsqrt);
258+
CHECK_RANKS_AND_SIZES_FOR(Abs);
259+
CHECK_RANKS_AND_SIZES_FOR(BitwiseNot);
260+
CHECK_RANKS_AND_SIZES_FOR(Ceil);
261+
CHECK_RANKS_AND_SIZES_FOR(Clz);
262+
CHECK_RANKS_AND_SIZES_FOR(Cos);
263+
CHECK_RANKS_AND_SIZES_FOR(Exp);
264+
CHECK_RANKS_AND_SIZES_FOR(Floor);
265+
CHECK_RANKS_AND_SIZES_FOR(Log);
266+
CHECK_RANKS_AND_SIZES_FOR(LogicalNot);
267+
CHECK_RANKS_AND_SIZES_FOR(Negate);
268+
CHECK_RANKS_AND_SIZES_FOR(Reciprocal);
269+
CHECK_RANKS_AND_SIZES_FOR(Rsqrt);
270+
CHECK_RANKS_AND_SIZES_FOR(Sin);
222271
// all elementwise ternary operators:
223-
CHECK_RANKS_FOR(Select);
272+
CHECK_RANKS_AND_SIZES_FOR(Select);
224273
// all comparison operators:
225-
CHECK_RANKS_FOR(Equal);
226-
CHECK_RANKS_FOR(Greater);
227-
CHECK_RANKS_FOR(GreaterEqual);
274+
CHECK_RANKS_AND_SIZES_FOR(Equal);
275+
CHECK_RANKS_AND_SIZES_FOR(Greater);
276+
CHECK_RANKS_AND_SIZES_FOR(GreaterEqual);
228277
// all reduction operators:
229-
CHECK_RANKS_FOR(ReduceAll);
230-
CHECK_RANKS_FOR(ReduceAny);
231-
CHECK_RANKS_FOR(ReduceMax);
232-
CHECK_RANKS_FOR(ReduceMin);
233-
CHECK_RANKS_FOR(ReduceProd);
234-
CHECK_RANKS_FOR(ReduceSum);
278+
CHECK_RANKS_AND_SIZES_FOR(ReduceAll);
279+
CHECK_RANKS_AND_SIZES_FOR(ReduceAny);
280+
CHECK_RANKS_AND_SIZES_FOR(ReduceMax);
281+
CHECK_RANKS_AND_SIZES_FOR(ReduceMin);
282+
CHECK_RANKS_AND_SIZES_FOR(ReduceProd);
283+
CHECK_RANKS_AND_SIZES_FOR(ReduceSum);
235284
// all data layout operators:
236-
CHECK_RANKS_FOR(Concat);
237-
CHECK_RANKS_FOR(Pad);
238-
CHECK_RANKS_FOR(Reshape);
239-
CHECK_RANKS_FOR(Reverse);
240-
CHECK_RANKS_FOR(Slice);
241-
CHECK_RANKS_FOR(Tile);
242-
CHECK_RANKS_FOR(Transpose);
285+
CHECK_RANKS_AND_SIZES_FOR(Concat);
286+
CHECK_RANKS_AND_SIZES_FOR(Pad);
287+
CHECK_RANKS_AND_SIZES_FOR(Reshape);
288+
CHECK_RANKS_AND_SIZES_FOR(Reverse);
289+
CHECK_RANKS_AND_SIZES_FOR(Slice);
290+
CHECK_RANKS_AND_SIZES_FOR(Tile);
291+
CHECK_RANKS_AND_SIZES_FOR(Transpose);
243292
// all type conversion operators:
244-
CHECK_RANKS_FOR(Cast);
245-
CHECK_RANKS_FOR(Rescale);
293+
CHECK_RANKS_AND_SIZES_FOR(Cast);
294+
CHECK_RANKS_AND_SIZES_FOR(Rescale);
246295
// all data nodes operators:
247-
CHECK_RANKS_FOR(Const);
248-
CHECK_RANKS_FOR(Identity);
296+
CHECK_RANKS_AND_SIZES_FOR(Const);
297+
CHECK_RANKS_AND_SIZES_FOR(Identity);
249298

250-
#undef CHECK_RANKS_FOR
299+
#undef CHECK_RANKS_AND_SIZES_FOR
251300
return true;
252301
}
253302

@@ -396,6 +445,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
396445
return true;
397446
}
398447

448+
bool levelCheckListSize(Operation *op) {
449+
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
450+
return levelCheckListSize(op, concat.getInput1().size(), "input1");
451+
}
452+
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
453+
if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
454+
!levelCheckListSize(op, custom.getOutputList().size(),
455+
"output_list")) {
456+
return false;
457+
}
458+
}
459+
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
460+
if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
461+
!levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
462+
return false;
463+
}
464+
}
465+
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
466+
if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
467+
!levelCheckListSize(op, w.getOutput().size(), "outputs")) {
468+
return false;
469+
}
470+
}
471+
return true;
472+
}
473+
399474
// configure profile and level values from pass options profileName and
400475
// levelName
401476
void configLevelAndProfile() {
@@ -449,7 +524,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
449524
return success();
450525
}
451526

452-
if (!levelCheckRanks(op)) {
527+
if (!levelCheckRanksAndSizes(op)) {
453528
return failure();
454529
}
455530

@@ -465,6 +540,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
465540
return failure();
466541
}
467542

543+
// level check MAX_TENSOR_LIST_SIZE
544+
if (!levelCheckListSize(op)) {
545+
return failure();
546+
}
547+
468548
return success();
469549
}
470550

@@ -695,6 +775,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
695775
}
696776

697777
bool TosaValidation::isValidElementType(Type type) {
778+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
779+
type = quantType.getStorageType();
780+
698781
if (isa<FloatType>(type)) {
699782
return type.isF32() || type.isF16() || type.isBF16();
700783
} else if (auto intTy = dyn_cast<IntegerType>(type)) {

0 commit comments

Comments
 (0)