Skip to content

Commit 9672966

Browse files
Tai78641tatwaichong
andcommitted
[mlir][tosa] Add several level checks
Add the following types of level check to consolidate the level validity - Add rank level checks for Erf, Cos, and Sin. - Add MAX_LOG2_SIZE level check: The maximum value is 63 when the level is set to "none" and 31 when the level is set to "8K". - Add MAX_TENSOR_LIST_SIZE level check : The maximum value is 256 when the level is set to "none" and 64 when the level is set to "8K". Co-authored-by: TatWai Chong <[email protected]> Change-Id: I797fafe504219e43950824c04839c7187065fe8e
1 parent 4c9e14b commit 9672966

File tree

2 files changed

+741
-75
lines changed

2 files changed

+741
-75
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 151 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,22 @@ struct TosaLevel {
7070
int32_t MAX_KERNEL = 0;
7171
int32_t MAX_STRIDE = 0;
7272
int32_t MAX_SCALE = 0;
73-
74-
// @todo: MAX_LOG2_SIZE value and checks
73+
int32_t MAX_LOG2_SIZE = 0;
74+
int32_t MAX_NESTING = 0;
75+
int32_t MAX_TENSOR_LIST_SIZE = 0;
7576

7677
bool operator==(const TosaLevel &rhs) {
7778
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
78-
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
79+
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
80+
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
81+
MAX_NESTING == rhs.MAX_NESTING &&
82+
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
7983
}
8084
};
8185

82-
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
83-
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
86+
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
87+
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
88+
63, 256, 256};
8489

8590
//===----------------------------------------------------------------------===//
8691
// TOSA Validation Pass.
@@ -147,107 +152,150 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
147152
return true;
148153
}
149154

150-
bool levelCheckRank(Operation *op, const Value &v,
151-
const std::string &checkDesc) {
155+
bool levelCheckListSize(Operation *op, int32_t v,
156+
const std::string &checkDesc) {
157+
if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
158+
op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
159+
<< checkDesc;
160+
return false;
161+
}
162+
return true;
163+
}
164+
165+
bool levelCheckRankAndSizes(Operation *op, const Value &v,
166+
const std::string &operandOrResult) {
152167
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
153168
if (!type.hasRank()) {
154169
op->emitOpError() << "failed level check: unranked tensor";
155170
return false;
156171
}
157172
if (type.getRank() > tosaLevel.MAX_RANK) {
158-
op->emitOpError() << "failed level check: " << checkDesc;
173+
op->emitOpError() << "failed level check: " << operandOrResult
174+
<< " rank(shape) <= MAX_RANK";
159175
return false;
160176
}
177+
178+
const int64_t max_dim = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
179+
const int64_t max_size = (INT64_C(1) << (tosaLevel.MAX_LOG2_SIZE + 1)) - 1;
180+
181+
auto shape = type.getShape();
182+
bool has_dynamic = false;
183+
for (auto dim : shape) {
184+
if (mlir::ShapedType::isDynamic(dim)) {
185+
has_dynamic = true;
186+
continue;
187+
}
188+
if (dim > max_dim) {
189+
op->emitOpError() << "failed level check: " << operandOrResult
190+
<< " shape dimension <= (1<<MAX_LOG2_SIZE) - 1";
191+
return false;
192+
}
193+
}
194+
if (!has_dynamic) {
195+
int64_t element_bits = type.getElementTypeBitWidth();
196+
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
197+
int64_t size = element_bytes * type.getNumElements();
198+
if (size > max_size) {
199+
op->emitOpError()
200+
<< "failed level check: " << operandOrResult << " tensor size "
201+
<< size << " (in bytes) <= "
202+
<< "(1<<MAX_LOG2_SIZE+1) - 1, where max_size = " << max_size;
203+
return false;
204+
}
205+
}
161206
}
162207
return true;
163208
}
164209

165210
template <typename T>
166-
bool levelCheckRanksFor(Operation *op) {
211+
bool levelCheckRanksAndSizesFor(Operation *op) {
167212
if (dyn_cast<T>(op)) {
168213
// level check ranks of all operands and results
169214
for (auto v : op->getOperands()) {
170-
if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
215+
if (!levelCheckRankAndSizes(op, v, "operand"))
171216
return false;
172217
}
173218
for (auto v : op->getResults()) {
174-
if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
219+
if (!levelCheckRankAndSizes(op, v, "result"))
175220
return false;
176221
}
177222
}
178223
return true;
179224
}
180225

181-
bool levelCheckRanks(Operation *op) {
182-
#define CHECK_RANKS_FOR(tosaOp) \
183-
if (!levelCheckRanksFor<tosaOp##Op>(op)) \
226+
bool levelCheckRanksAndSizes(Operation *op) {
227+
#define CHECK_RANKS_AND_SIZES_FOR(tosaOp) \
228+
if (!levelCheckRanksAndSizesFor<tosaOp##Op>(op)) \
184229
return false;
185230

186231
// tensor operators:
187-
CHECK_RANKS_FOR(ArgMax);
232+
CHECK_RANKS_AND_SIZES_FOR(ArgMax);
188233
// all activation functions:
189-
CHECK_RANKS_FOR(Clamp);
190-
CHECK_RANKS_FOR(Sigmoid);
191-
CHECK_RANKS_FOR(Tanh);
234+
CHECK_RANKS_AND_SIZES_FOR(Clamp);
235+
CHECK_RANKS_AND_SIZES_FOR(Erf);
236+
CHECK_RANKS_AND_SIZES_FOR(Sigmoid);
237+
CHECK_RANKS_AND_SIZES_FOR(Tanh);
192238
// all elementwise binary operators:
193-
CHECK_RANKS_FOR(Add);
194-
CHECK_RANKS_FOR(ArithmeticRightShift);
195-
CHECK_RANKS_FOR(BitwiseAnd);
196-
CHECK_RANKS_FOR(BitwiseOr);
197-
CHECK_RANKS_FOR(BitwiseXor);
198-
CHECK_RANKS_FOR(IntDiv);
199-
CHECK_RANKS_FOR(LogicalAnd);
200-
CHECK_RANKS_FOR(LogicalLeftShift);
201-
CHECK_RANKS_FOR(LogicalRightShift);
202-
CHECK_RANKS_FOR(LogicalOr);
203-
CHECK_RANKS_FOR(LogicalXor);
204-
CHECK_RANKS_FOR(Maximum);
205-
CHECK_RANKS_FOR(Minimum);
206-
CHECK_RANKS_FOR(Mul);
207-
CHECK_RANKS_FOR(Pow);
208-
CHECK_RANKS_FOR(Sub);
209-
CHECK_RANKS_FOR(Table);
239+
CHECK_RANKS_AND_SIZES_FOR(Add);
240+
CHECK_RANKS_AND_SIZES_FOR(ArithmeticRightShift);
241+
CHECK_RANKS_AND_SIZES_FOR(BitwiseAnd);
242+
CHECK_RANKS_AND_SIZES_FOR(BitwiseOr);
243+
CHECK_RANKS_AND_SIZES_FOR(BitwiseXor);
244+
CHECK_RANKS_AND_SIZES_FOR(IntDiv);
245+
CHECK_RANKS_AND_SIZES_FOR(LogicalAnd);
246+
CHECK_RANKS_AND_SIZES_FOR(LogicalLeftShift);
247+
CHECK_RANKS_AND_SIZES_FOR(LogicalRightShift);
248+
CHECK_RANKS_AND_SIZES_FOR(LogicalOr);
249+
CHECK_RANKS_AND_SIZES_FOR(LogicalXor);
250+
CHECK_RANKS_AND_SIZES_FOR(Maximum);
251+
CHECK_RANKS_AND_SIZES_FOR(Minimum);
252+
CHECK_RANKS_AND_SIZES_FOR(Mul);
253+
CHECK_RANKS_AND_SIZES_FOR(Pow);
254+
CHECK_RANKS_AND_SIZES_FOR(Sub);
255+
CHECK_RANKS_AND_SIZES_FOR(Table);
210256
// all elementwise unary operators:
211-
CHECK_RANKS_FOR(Abs);
212-
CHECK_RANKS_FOR(BitwiseNot);
213-
CHECK_RANKS_FOR(Ceil);
214-
CHECK_RANKS_FOR(Clz);
215-
CHECK_RANKS_FOR(Exp);
216-
CHECK_RANKS_FOR(Floor);
217-
CHECK_RANKS_FOR(Log);
218-
CHECK_RANKS_FOR(LogicalNot);
219-
CHECK_RANKS_FOR(Negate);
220-
CHECK_RANKS_FOR(Reciprocal);
221-
CHECK_RANKS_FOR(Rsqrt);
257+
CHECK_RANKS_AND_SIZES_FOR(Abs);
258+
CHECK_RANKS_AND_SIZES_FOR(BitwiseNot);
259+
CHECK_RANKS_AND_SIZES_FOR(Ceil);
260+
CHECK_RANKS_AND_SIZES_FOR(Clz);
261+
CHECK_RANKS_AND_SIZES_FOR(Cos);
262+
CHECK_RANKS_AND_SIZES_FOR(Exp);
263+
CHECK_RANKS_AND_SIZES_FOR(Floor);
264+
CHECK_RANKS_AND_SIZES_FOR(Log);
265+
CHECK_RANKS_AND_SIZES_FOR(LogicalNot);
266+
CHECK_RANKS_AND_SIZES_FOR(Negate);
267+
CHECK_RANKS_AND_SIZES_FOR(Reciprocal);
268+
CHECK_RANKS_AND_SIZES_FOR(Rsqrt);
269+
CHECK_RANKS_AND_SIZES_FOR(Sin);
222270
// all elementwise ternary operators:
223-
CHECK_RANKS_FOR(Select);
271+
CHECK_RANKS_AND_SIZES_FOR(Select);
224272
// all comparison operators:
225-
CHECK_RANKS_FOR(Equal);
226-
CHECK_RANKS_FOR(Greater);
227-
CHECK_RANKS_FOR(GreaterEqual);
273+
CHECK_RANKS_AND_SIZES_FOR(Equal);
274+
CHECK_RANKS_AND_SIZES_FOR(Greater);
275+
CHECK_RANKS_AND_SIZES_FOR(GreaterEqual);
228276
// all reduction operators:
229-
CHECK_RANKS_FOR(ReduceAll);
230-
CHECK_RANKS_FOR(ReduceAny);
231-
CHECK_RANKS_FOR(ReduceMax);
232-
CHECK_RANKS_FOR(ReduceMin);
233-
CHECK_RANKS_FOR(ReduceProd);
234-
CHECK_RANKS_FOR(ReduceSum);
277+
CHECK_RANKS_AND_SIZES_FOR(ReduceAll);
278+
CHECK_RANKS_AND_SIZES_FOR(ReduceAny);
279+
CHECK_RANKS_AND_SIZES_FOR(ReduceMax);
280+
CHECK_RANKS_AND_SIZES_FOR(ReduceMin);
281+
CHECK_RANKS_AND_SIZES_FOR(ReduceProd);
282+
CHECK_RANKS_AND_SIZES_FOR(ReduceSum);
235283
// all data layout operators:
236-
CHECK_RANKS_FOR(Concat);
237-
CHECK_RANKS_FOR(Pad);
238-
CHECK_RANKS_FOR(Reshape);
239-
CHECK_RANKS_FOR(Reverse);
240-
CHECK_RANKS_FOR(Slice);
241-
CHECK_RANKS_FOR(Tile);
242-
CHECK_RANKS_FOR(Transpose);
284+
CHECK_RANKS_AND_SIZES_FOR(Concat);
285+
CHECK_RANKS_AND_SIZES_FOR(Pad);
286+
CHECK_RANKS_AND_SIZES_FOR(Reshape);
287+
CHECK_RANKS_AND_SIZES_FOR(Reverse);
288+
CHECK_RANKS_AND_SIZES_FOR(Slice);
289+
CHECK_RANKS_AND_SIZES_FOR(Tile);
290+
CHECK_RANKS_AND_SIZES_FOR(Transpose);
243291
// all type conversion operators:
244-
CHECK_RANKS_FOR(Cast);
245-
CHECK_RANKS_FOR(Rescale);
292+
CHECK_RANKS_AND_SIZES_FOR(Cast);
293+
CHECK_RANKS_AND_SIZES_FOR(Rescale);
246294
// all data nodes operators:
247-
CHECK_RANKS_FOR(Const);
248-
CHECK_RANKS_FOR(Identity);
295+
CHECK_RANKS_AND_SIZES_FOR(Const);
296+
CHECK_RANKS_AND_SIZES_FOR(Identity);
249297

250-
#undef CHECK_RANKS_FOR
298+
#undef CHECK_RANKS_AND_SIZES_FOR
251299
return true;
252300
}
253301

@@ -396,6 +444,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
396444
return true;
397445
}
398446

447+
bool levelCheckListSize(Operation *op) {
448+
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
449+
return levelCheckListSize(op, concat.getInput1().size(), "input1");
450+
}
451+
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
452+
if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
453+
!levelCheckListSize(op, custom.getOutputList().size(),
454+
"output_list")) {
455+
return false;
456+
}
457+
}
458+
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
459+
if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
460+
!levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
461+
return false;
462+
}
463+
}
464+
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
465+
if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
466+
!levelCheckListSize(op, w.getOutput().size(), "outputs")) {
467+
return false;
468+
}
469+
}
470+
return true;
471+
}
472+
399473
// configure profile and level values from pass options profileName and
400474
// levelName
401475
void configLevelAndProfile() {
@@ -449,7 +523,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
449523
return success();
450524
}
451525

452-
if (!levelCheckRanks(op)) {
526+
if (!levelCheckRanksAndSizes(op)) {
453527
return failure();
454528
}
455529

@@ -465,6 +539,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
465539
return failure();
466540
}
467541

542+
// level check MAX_TENSOR_LIST_SIZE
543+
if (!levelCheckListSize(op)) {
544+
return failure();
545+
}
546+
468547
return success();
469548
}
470549

@@ -695,6 +774,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
695774
}
696775

697776
bool TosaValidation::isValidElementType(Type type) {
777+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
778+
type = quantType.getStorageType();
779+
698780
if (isa<FloatType>(type)) {
699781
return type.isF32() || type.isF16() || type.isBF16();
700782
} else if (auto intTy = dyn_cast<IntegerType>(type)) {

0 commit comments

Comments
 (0)