Skip to content

Commit 5af3873

Browse files
Tai78641tatwaichong
andcommitted
[mlir][tosa] Add several level checks
Add the following types of level check to consolidate the level validity - Add rank level checks for ERF, COS, SIN and COND_IF. - Add MAX_LOG2_SIZE level check: The maximum value is 63 when the level is set to "none" and 31 when the level is set to "8K". - Add MAX_TENSOR_LIST_SIZE level check : The maximum value is 256 when the level is set to "none" and 64 when the level is set to "8K". - TOSA 1.0 spec does not allow operations with dynamic shapes, so an error should be raised instead. Co-authored-by: TatWai Chong <[email protected]> Change-Id: I797fafe504219e43950824c04839c7187065fe8e
1 parent 7371f69 commit 5af3873

File tree

2 files changed

+826
-100
lines changed

2 files changed

+826
-100
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 199 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,22 @@ struct TosaLevel {
6161
int32_t MAX_KERNEL = 0;
6262
int32_t MAX_STRIDE = 0;
6363
int32_t MAX_SCALE = 0;
64-
65-
// @todo: MAX_LOG2_SIZE value and checks
64+
int32_t MAX_LOG2_SIZE = 0;
65+
int32_t MAX_NESTING = 0;
66+
int32_t MAX_TENSOR_LIST_SIZE = 0;
6667

6768
bool operator==(const TosaLevel &rhs) {
6869
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
69-
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
70+
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
71+
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
72+
MAX_NESTING == rhs.MAX_NESTING &&
73+
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
7074
}
7175
};
7276

73-
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
74-
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
77+
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
78+
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
79+
63, 256, 256};
7580

7681
//===----------------------------------------------------------------------===//
7782
// TOSA Validation Pass.
@@ -137,107 +142,188 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
137142
return true;
138143
}
139144

140-
bool levelCheckRank(Operation *op, const Value &v,
141-
const std::string &checkDesc) {
145+
bool levelCheckListSize(Operation *op, int32_t v,
146+
const std::string &checkDesc) {
147+
if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
148+
op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
149+
<< checkDesc;
150+
return false;
151+
}
152+
return true;
153+
}
154+
155+
bool levelCheckRankAndSizes(Operation *op, const Value &v,
156+
const std::string &operandOrResult,
157+
int32_t highest_rank) {
142158
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
143159
if (!type.hasRank()) {
144160
op->emitOpError() << "failed level check: unranked tensor";
145161
return false;
146162
}
147-
if (type.getRank() > tosaLevel.MAX_RANK) {
148-
op->emitOpError() << "failed level check: " << checkDesc;
163+
if (type.getRank() > highest_rank) {
164+
op->emitOpError() << "failed level check: " << operandOrResult
165+
<< " rank(shape) <= MAX_RANK";
166+
return false;
167+
}
168+
169+
auto shape = type.getShape();
170+
for (auto dim : shape) {
171+
if (mlir::ShapedType::isDynamic(dim)) {
172+
op->emitOpError() << "failed level check: " << operandOrResult
173+
<< " shape dimension cannot be dynamic";
174+
return false;
175+
}
176+
}
177+
178+
int64_t element_bits = type.getElementTypeBitWidth();
179+
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
180+
int64_t size = element_bytes * type.getNumElements();
181+
182+
// According to 1.11. Tensor Definitions of Tosa spec, the value of
183+
// tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
184+
// defined in 1.7. Levels.
185+
// For each tensor, the number of tensor elements multiplied by the
186+
// element size in bytes must be representable as a tensor_size_t.
187+
const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
188+
if (size > max_size) {
189+
op->emitOpError()
190+
<< "failed level check: " << operandOrResult
191+
<< " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
149192
return false;
150193
}
151194
}
152195
return true;
153196
}
154197

155198
template <typename T>
156-
bool levelCheckRanksFor(Operation *op) {
157-
if (dyn_cast<T>(op)) {
158-
// level check ranks of all operands and results
159-
for (auto v : op->getOperands()) {
160-
if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
161-
return false;
162-
}
163-
for (auto v : op->getResults()) {
164-
if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
165-
return false;
166-
}
199+
bool levelCheckRanksAndSizesFor(T tosaOp) {
200+
// level check ranks of all operands and results
201+
auto op = tosaOp.getOperation();
202+
for (auto v : op->getOperands()) {
203+
if (!levelCheckRankAndSizes(op, v, "operand", tosaLevel.MAX_RANK))
204+
return false;
205+
}
206+
207+
for (auto v : op->getResults()) {
208+
if (!levelCheckRankAndSizes(op, v, "result", tosaLevel.MAX_RANK))
209+
return false;
167210
}
168211
return true;
169212
}
170213

171-
bool levelCheckRanks(Operation *op) {
172-
#define CHECK_RANKS_FOR(tosaOp) \
173-
if (!levelCheckRanksFor<tosaOp##Op>(op)) \
174-
return false;
214+
template <>
215+
bool levelCheckRanksAndSizesFor(tosa::ArgMaxOp tosaOp) {
216+
auto op = tosaOp.getOperation();
217+
if (!levelCheckRankAndSizes(op, tosaOp.getInput(), "operand",
218+
tosaLevel.MAX_RANK))
219+
return false;
220+
221+
// rank(output) = rank(input) - 1
222+
if (!levelCheckRankAndSizes(op, tosaOp.getOutput(), "result",
223+
tosaLevel.MAX_RANK - 1))
224+
return false;
225+
226+
return true;
227+
}
228+
229+
template <>
230+
bool levelCheckRanksAndSizesFor(tosa::IfOp tosaOp) {
231+
auto op = tosaOp.getOperation();
232+
233+
// Only the condition input has rank limitation.
234+
if (!levelCheckRankAndSizes(op, tosaOp.getCond(), "operand",
235+
tosaLevel.MAX_RANK))
236+
return false;
237+
238+
return true;
239+
}
240+
241+
bool levelCheckRanksAndSizes(Operation *op) {
242+
#define CHECK_RANKS_AND_SIZES_FOR(tosaOp) \
243+
if (isa<tosa::tosaOp##Op>(op)) \
244+
if (!levelCheckRanksAndSizesFor(cast<tosa::tosaOp##Op>(op))) \
245+
return false;
246+
247+
#define CHECK_RANKS_AND_SIZES_SKIP(tosaOp) \
248+
if (isa<tosa::tosaOp##Op>(op)) \
249+
return true;
175250

176251
// tensor operators:
177-
CHECK_RANKS_FOR(ArgMax);
252+
CHECK_RANKS_AND_SIZES_FOR(ArgMax);
178253
// all activation functions:
179-
CHECK_RANKS_FOR(Clamp);
180-
CHECK_RANKS_FOR(Sigmoid);
181-
CHECK_RANKS_FOR(Tanh);
254+
CHECK_RANKS_AND_SIZES_FOR(Clamp);
255+
CHECK_RANKS_AND_SIZES_FOR(Erf);
256+
CHECK_RANKS_AND_SIZES_FOR(Sigmoid);
257+
CHECK_RANKS_AND_SIZES_FOR(Tanh);
182258
// all elementwise binary operators:
183-
CHECK_RANKS_FOR(Add);
184-
CHECK_RANKS_FOR(ArithmeticRightShift);
185-
CHECK_RANKS_FOR(BitwiseAnd);
186-
CHECK_RANKS_FOR(BitwiseOr);
187-
CHECK_RANKS_FOR(BitwiseXor);
188-
CHECK_RANKS_FOR(IntDiv);
189-
CHECK_RANKS_FOR(LogicalAnd);
190-
CHECK_RANKS_FOR(LogicalLeftShift);
191-
CHECK_RANKS_FOR(LogicalRightShift);
192-
CHECK_RANKS_FOR(LogicalOr);
193-
CHECK_RANKS_FOR(LogicalXor);
194-
CHECK_RANKS_FOR(Maximum);
195-
CHECK_RANKS_FOR(Minimum);
196-
CHECK_RANKS_FOR(Mul);
197-
CHECK_RANKS_FOR(Pow);
198-
CHECK_RANKS_FOR(Sub);
199-
CHECK_RANKS_FOR(Table);
259+
CHECK_RANKS_AND_SIZES_FOR(Add);
260+
CHECK_RANKS_AND_SIZES_FOR(ArithmeticRightShift);
261+
CHECK_RANKS_AND_SIZES_FOR(BitwiseAnd);
262+
CHECK_RANKS_AND_SIZES_FOR(BitwiseOr);
263+
CHECK_RANKS_AND_SIZES_FOR(BitwiseXor);
264+
CHECK_RANKS_AND_SIZES_FOR(IntDiv);
265+
CHECK_RANKS_AND_SIZES_FOR(LogicalAnd);
266+
CHECK_RANKS_AND_SIZES_FOR(LogicalLeftShift);
267+
CHECK_RANKS_AND_SIZES_FOR(LogicalRightShift);
268+
CHECK_RANKS_AND_SIZES_FOR(LogicalOr);
269+
CHECK_RANKS_AND_SIZES_FOR(LogicalXor);
270+
CHECK_RANKS_AND_SIZES_FOR(Maximum);
271+
CHECK_RANKS_AND_SIZES_FOR(Minimum);
272+
CHECK_RANKS_AND_SIZES_FOR(Mul);
273+
CHECK_RANKS_AND_SIZES_FOR(Pow);
274+
CHECK_RANKS_AND_SIZES_FOR(Sub);
275+
CHECK_RANKS_AND_SIZES_FOR(Table);
200276
// all elementwise unary operators:
201-
CHECK_RANKS_FOR(Abs);
202-
CHECK_RANKS_FOR(BitwiseNot);
203-
CHECK_RANKS_FOR(Ceil);
204-
CHECK_RANKS_FOR(Clz);
205-
CHECK_RANKS_FOR(Exp);
206-
CHECK_RANKS_FOR(Floor);
207-
CHECK_RANKS_FOR(Log);
208-
CHECK_RANKS_FOR(LogicalNot);
209-
CHECK_RANKS_FOR(Negate);
210-
CHECK_RANKS_FOR(Reciprocal);
211-
CHECK_RANKS_FOR(Rsqrt);
277+
CHECK_RANKS_AND_SIZES_FOR(Abs);
278+
CHECK_RANKS_AND_SIZES_FOR(BitwiseNot);
279+
CHECK_RANKS_AND_SIZES_FOR(Ceil);
280+
CHECK_RANKS_AND_SIZES_FOR(Clz);
281+
CHECK_RANKS_AND_SIZES_FOR(Cos);
282+
CHECK_RANKS_AND_SIZES_FOR(Exp);
283+
CHECK_RANKS_AND_SIZES_FOR(Floor);
284+
CHECK_RANKS_AND_SIZES_FOR(Log);
285+
CHECK_RANKS_AND_SIZES_FOR(LogicalNot);
286+
CHECK_RANKS_AND_SIZES_FOR(Negate);
287+
CHECK_RANKS_AND_SIZES_FOR(Reciprocal);
288+
CHECK_RANKS_AND_SIZES_FOR(Rsqrt);
289+
CHECK_RANKS_AND_SIZES_FOR(Sin);
212290
// all elementwise ternary operators:
213-
CHECK_RANKS_FOR(Select);
291+
CHECK_RANKS_AND_SIZES_FOR(Select);
214292
// all comparison operators:
215-
CHECK_RANKS_FOR(Equal);
216-
CHECK_RANKS_FOR(Greater);
217-
CHECK_RANKS_FOR(GreaterEqual);
293+
CHECK_RANKS_AND_SIZES_FOR(Equal);
294+
CHECK_RANKS_AND_SIZES_FOR(Greater);
295+
CHECK_RANKS_AND_SIZES_FOR(GreaterEqual);
218296
// all reduction operators:
219-
CHECK_RANKS_FOR(ReduceAll);
220-
CHECK_RANKS_FOR(ReduceAny);
221-
CHECK_RANKS_FOR(ReduceMax);
222-
CHECK_RANKS_FOR(ReduceMin);
223-
CHECK_RANKS_FOR(ReduceProduct);
224-
CHECK_RANKS_FOR(ReduceSum);
297+
CHECK_RANKS_AND_SIZES_FOR(ReduceAll);
298+
CHECK_RANKS_AND_SIZES_FOR(ReduceAny);
299+
CHECK_RANKS_AND_SIZES_FOR(ReduceMax);
300+
CHECK_RANKS_AND_SIZES_FOR(ReduceMin);
301+
CHECK_RANKS_AND_SIZES_FOR(ReduceProduct);
302+
CHECK_RANKS_AND_SIZES_FOR(ReduceSum);
225303
// all data layout operators:
226-
CHECK_RANKS_FOR(Concat);
227-
CHECK_RANKS_FOR(Pad);
228-
CHECK_RANKS_FOR(Reshape);
229-
CHECK_RANKS_FOR(Reverse);
230-
CHECK_RANKS_FOR(Slice);
231-
CHECK_RANKS_FOR(Tile);
232-
CHECK_RANKS_FOR(Transpose);
304+
CHECK_RANKS_AND_SIZES_FOR(Concat);
305+
CHECK_RANKS_AND_SIZES_FOR(Pad);
306+
CHECK_RANKS_AND_SIZES_FOR(Reshape);
307+
CHECK_RANKS_AND_SIZES_FOR(Reverse);
308+
CHECK_RANKS_AND_SIZES_FOR(Slice);
309+
CHECK_RANKS_AND_SIZES_FOR(Tile);
310+
CHECK_RANKS_AND_SIZES_FOR(Transpose);
233311
// all type conversion operators:
234-
CHECK_RANKS_FOR(Cast);
235-
CHECK_RANKS_FOR(Rescale);
312+
CHECK_RANKS_AND_SIZES_FOR(Cast);
313+
CHECK_RANKS_AND_SIZES_FOR(Rescale);
314+
// control flow operators:
315+
CHECK_RANKS_AND_SIZES_FOR(If);
236316
// all data nodes operators:
237-
CHECK_RANKS_FOR(Const);
238-
CHECK_RANKS_FOR(Identity);
317+
CHECK_RANKS_AND_SIZES_FOR(Const);
318+
CHECK_RANKS_AND_SIZES_FOR(Identity);
319+
320+
// The following operators do not have level rank and size constraint.
321+
CHECK_RANKS_AND_SIZES_SKIP(Yield);
322+
CHECK_RANKS_AND_SIZES_SKIP(Custom);
323+
CHECK_RANKS_AND_SIZES_SKIP(While);
239324

240-
#undef CHECK_RANKS_FOR
325+
#undef CHECK_RANKS_AND_SIZES_FOR
326+
#undef CHECK_RANKS_AND_SIZES_SKIP
241327
return true;
242328
}
243329

@@ -386,6 +472,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386472
return true;
387473
}
388474

475+
bool levelCheckListSize(Operation *op) {
476+
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
477+
return levelCheckListSize(op, concat.getInput1().size(), "input1");
478+
}
479+
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
480+
if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
481+
!levelCheckListSize(op, custom.getOutputList().size(),
482+
"output_list")) {
483+
return false;
484+
}
485+
}
486+
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
487+
if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
488+
!levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
489+
return false;
490+
}
491+
}
492+
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
493+
if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
494+
!levelCheckListSize(op, w.getOutput().size(), "outputs")) {
495+
return false;
496+
}
497+
}
498+
return true;
499+
}
500+
389501
// configure profile and level values from pass options profileName and
390502
// levelName
391503
void configLevelAndProfile() {
@@ -439,7 +551,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439551
return success();
440552
}
441553

442-
if (!levelCheckRanks(op)) {
554+
if (!levelCheckRanksAndSizes(op)) {
443555
return failure();
444556
}
445557

@@ -455,6 +567,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
455567
return failure();
456568
}
457569

570+
// level check MAX_TENSOR_LIST_SIZE
571+
if (!levelCheckListSize(op)) {
572+
return failure();
573+
}
574+
458575
return success();
459576
}
460577

@@ -685,6 +802,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
685802
}
686803

687804
bool TosaValidation::isValidElementType(Type type) {
805+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
806+
type = quantType.getStorageType();
807+
688808
if (isa<FloatType>(type)) {
689809
return type.isF32() || type.isF16() || type.isBF16();
690810
} else if (auto intTy = dyn_cast<IntegerType>(type)) {

0 commit comments

Comments
 (0)