Skip to content

Commit 23707c2

Browse files
Tai78641tatwaichong
andcommitted
[mlir][tosa] Add several level checks
Add the following types of level check to consolidate the level validity - Add rank level checks for Erf, Cos, and Sin. - Add MAX_LOG2_SIZE level check: The maximum value is 63 when the level is set to "none" and 31 when the level is set to "8K". - Add MAX_TENSOR_LIST_SIZE level check : The maximum value is 256 when the level is set to "none" and 64 when the level is set to "8K". Co-authored-by: TatWai Chong <[email protected]> Change-Id: I797fafe504219e43950824c04839c7187065fe8e
1 parent 3c46deb commit 23707c2

File tree

2 files changed

+740
-71
lines changed

2 files changed

+740
-71
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 150 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,22 @@ struct TosaLevel {
7070
int32_t MAX_KERNEL = 0;
7171
int32_t MAX_STRIDE = 0;
7272
int32_t MAX_SCALE = 0;
73-
74-
// @todo: MAX_LOG2_SIZE value and checks
73+
int32_t MAX_LOG2_SIZE = 0;
74+
int32_t MAX_NESTING = 0;
75+
int32_t MAX_TENSOR_LIST_SIZE = 0;
7576

7677
bool operator==(const TosaLevel &rhs) {
7778
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
78-
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
79+
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
80+
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
81+
MAX_NESTING == rhs.MAX_NESTING &&
82+
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
7983
}
8084
};
8185

82-
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
83-
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
86+
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
87+
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
88+
63, 256, 256};
8489

8590
//===----------------------------------------------------------------------===//
8691
// TOSA Validation Pass.
@@ -147,107 +152,149 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
147152
return true;
148153
}
149154

150-
bool levelCheckRank(Operation *op, const Value &v,
151-
const std::string &checkDesc) {
155+
bool levelCheckListSize(Operation *op, int32_t v,
156+
const std::string &checkDesc) {
157+
if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
158+
op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
159+
<< checkDesc;
160+
return false;
161+
}
162+
return true;
163+
}
164+
165+
bool levelCheckRankAndSizes(Operation *op, const Value &v,
166+
const std::string &operandOrResult) {
152167
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
153168
if (!type.hasRank()) {
154169
op->emitOpError() << "failed level check: unranked tensor";
155170
return false;
156171
}
157172
if (type.getRank() > tosaLevel.MAX_RANK) {
158-
op->emitOpError() << "failed level check: " << checkDesc;
173+
op->emitOpError() << "failed level check: " << operandOrResult
174+
<< " rank(shape) <= MAX_RANK";
159175
return false;
160176
}
177+
178+
const int64_t max_dim = (1L << tosaLevel.MAX_LOG2_SIZE) - 1;
179+
const int64_t max_size = (1L << (tosaLevel.MAX_LOG2_SIZE + 1)) - 1;
180+
181+
auto shape = type.getShape();
182+
bool has_dynamic = false;
183+
for (auto dim : shape) {
184+
if (mlir::ShapedType::isDynamic(dim)) {
185+
has_dynamic = true;
186+
continue;
187+
}
188+
if (dim > max_dim) {
189+
op->emitOpError() << "failed level check: " << operandOrResult
190+
<< " shape dimension <= (1<<MAX_LOG2_SIZE) - 1";
191+
return false;
192+
}
193+
}
194+
if (!has_dynamic) {
195+
int64_t element_bits = type.getElementTypeBitWidth();
196+
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
197+
int64_t size = element_bytes * type.getNumElements();
198+
if (size > max_size) {
199+
op->emitOpError()
200+
<< "failed level check: " << operandOrResult
201+
<< " tensor size (in bytes) <= (1<<MAX_LOG2_SIZE+1) - 1";
202+
return false;
203+
}
204+
}
161205
}
162206
return true;
163207
}
164208

165209
template <typename T>
166-
bool levelCheckRanksFor(Operation *op) {
210+
bool levelCheckRanksAndSizesFor(Operation *op) {
167211
if (dyn_cast<T>(op)) {
168212
// level check ranks of all operands and results
169213
for (auto v : op->getOperands()) {
170-
if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
214+
if (!levelCheckRankAndSizes(op, v, "operand"))
171215
return false;
172216
}
173217
for (auto v : op->getResults()) {
174-
if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
218+
if (!levelCheckRankAndSizes(op, v, "result"))
175219
return false;
176220
}
177221
}
178222
return true;
179223
}
180224

181-
bool levelCheckRanks(Operation *op) {
182-
#define CHECK_RANKS_FOR(tosaOp) \
183-
if (!levelCheckRanksFor<tosaOp##Op>(op)) \
225+
bool levelCheckRanksAndSizes(Operation *op) {
226+
#define CHECK_RANKS_AND_SIZES_FOR(tosaOp) \
227+
if (!levelCheckRanksAndSizesFor<tosaOp##Op>(op)) \
184228
return false;
185229

186230
// tensor operators:
187-
CHECK_RANKS_FOR(ArgMax);
231+
CHECK_RANKS_AND_SIZES_FOR(ArgMax);
188232
// all activation functions:
189-
CHECK_RANKS_FOR(Clamp);
190-
CHECK_RANKS_FOR(Sigmoid);
191-
CHECK_RANKS_FOR(Tanh);
233+
CHECK_RANKS_AND_SIZES_FOR(Clamp);
234+
CHECK_RANKS_AND_SIZES_FOR(Erf);
235+
CHECK_RANKS_AND_SIZES_FOR(Sigmoid);
236+
CHECK_RANKS_AND_SIZES_FOR(Tanh);
192237
// all elementwise binary operators:
193-
CHECK_RANKS_FOR(Add);
194-
CHECK_RANKS_FOR(ArithmeticRightShift);
195-
CHECK_RANKS_FOR(BitwiseAnd);
196-
CHECK_RANKS_FOR(BitwiseOr);
197-
CHECK_RANKS_FOR(BitwiseXor);
198-
CHECK_RANKS_FOR(IntDiv);
199-
CHECK_RANKS_FOR(LogicalAnd);
200-
CHECK_RANKS_FOR(LogicalLeftShift);
201-
CHECK_RANKS_FOR(LogicalRightShift);
202-
CHECK_RANKS_FOR(LogicalOr);
203-
CHECK_RANKS_FOR(LogicalXor);
204-
CHECK_RANKS_FOR(Maximum);
205-
CHECK_RANKS_FOR(Minimum);
206-
CHECK_RANKS_FOR(Mul);
207-
CHECK_RANKS_FOR(Pow);
208-
CHECK_RANKS_FOR(Sub);
209-
CHECK_RANKS_FOR(Table);
238+
CHECK_RANKS_AND_SIZES_FOR(Add);
239+
CHECK_RANKS_AND_SIZES_FOR(ArithmeticRightShift);
240+
CHECK_RANKS_AND_SIZES_FOR(BitwiseAnd);
241+
CHECK_RANKS_AND_SIZES_FOR(BitwiseOr);
242+
CHECK_RANKS_AND_SIZES_FOR(BitwiseXor);
243+
CHECK_RANKS_AND_SIZES_FOR(IntDiv);
244+
CHECK_RANKS_AND_SIZES_FOR(LogicalAnd);
245+
CHECK_RANKS_AND_SIZES_FOR(LogicalLeftShift);
246+
CHECK_RANKS_AND_SIZES_FOR(LogicalRightShift);
247+
CHECK_RANKS_AND_SIZES_FOR(LogicalOr);
248+
CHECK_RANKS_AND_SIZES_FOR(LogicalXor);
249+
CHECK_RANKS_AND_SIZES_FOR(Maximum);
250+
CHECK_RANKS_AND_SIZES_FOR(Minimum);
251+
CHECK_RANKS_AND_SIZES_FOR(Mul);
252+
CHECK_RANKS_AND_SIZES_FOR(Pow);
253+
CHECK_RANKS_AND_SIZES_FOR(Sub);
254+
CHECK_RANKS_AND_SIZES_FOR(Table);
210255
// all elementwise unary operators:
211-
CHECK_RANKS_FOR(Abs);
212-
CHECK_RANKS_FOR(BitwiseNot);
213-
CHECK_RANKS_FOR(Ceil);
214-
CHECK_RANKS_FOR(Clz);
215-
CHECK_RANKS_FOR(Exp);
216-
CHECK_RANKS_FOR(Floor);
217-
CHECK_RANKS_FOR(Log);
218-
CHECK_RANKS_FOR(LogicalNot);
219-
CHECK_RANKS_FOR(Negate);
220-
CHECK_RANKS_FOR(Reciprocal);
221-
CHECK_RANKS_FOR(Rsqrt);
256+
CHECK_RANKS_AND_SIZES_FOR(Abs);
257+
CHECK_RANKS_AND_SIZES_FOR(BitwiseNot);
258+
CHECK_RANKS_AND_SIZES_FOR(Ceil);
259+
CHECK_RANKS_AND_SIZES_FOR(Clz);
260+
CHECK_RANKS_AND_SIZES_FOR(Cos);
261+
CHECK_RANKS_AND_SIZES_FOR(Exp);
262+
CHECK_RANKS_AND_SIZES_FOR(Floor);
263+
CHECK_RANKS_AND_SIZES_FOR(Log);
264+
CHECK_RANKS_AND_SIZES_FOR(LogicalNot);
265+
CHECK_RANKS_AND_SIZES_FOR(Negate);
266+
CHECK_RANKS_AND_SIZES_FOR(Reciprocal);
267+
CHECK_RANKS_AND_SIZES_FOR(Rsqrt);
268+
CHECK_RANKS_AND_SIZES_FOR(Sin);
222269
// all elementwise ternary operators:
223-
CHECK_RANKS_FOR(Select);
270+
CHECK_RANKS_AND_SIZES_FOR(Select);
224271
// all comparison operators:
225-
CHECK_RANKS_FOR(Equal);
226-
CHECK_RANKS_FOR(Greater);
227-
CHECK_RANKS_FOR(GreaterEqual);
272+
CHECK_RANKS_AND_SIZES_FOR(Equal);
273+
CHECK_RANKS_AND_SIZES_FOR(Greater);
274+
CHECK_RANKS_AND_SIZES_FOR(GreaterEqual);
228275
// all reduction operators:
229-
CHECK_RANKS_FOR(ReduceAll);
230-
CHECK_RANKS_FOR(ReduceAny);
231-
CHECK_RANKS_FOR(ReduceMax);
232-
CHECK_RANKS_FOR(ReduceMin);
233-
CHECK_RANKS_FOR(ReduceProd);
234-
CHECK_RANKS_FOR(ReduceSum);
276+
CHECK_RANKS_AND_SIZES_FOR(ReduceAll);
277+
CHECK_RANKS_AND_SIZES_FOR(ReduceAny);
278+
CHECK_RANKS_AND_SIZES_FOR(ReduceMax);
279+
CHECK_RANKS_AND_SIZES_FOR(ReduceMin);
280+
CHECK_RANKS_AND_SIZES_FOR(ReduceProd);
281+
CHECK_RANKS_AND_SIZES_FOR(ReduceSum);
235282
// all data layout operators:
236-
CHECK_RANKS_FOR(Concat);
237-
CHECK_RANKS_FOR(Pad);
238-
CHECK_RANKS_FOR(Reshape);
239-
CHECK_RANKS_FOR(Reverse);
240-
CHECK_RANKS_FOR(Slice);
241-
CHECK_RANKS_FOR(Tile);
242-
CHECK_RANKS_FOR(Transpose);
283+
CHECK_RANKS_AND_SIZES_FOR(Concat);
284+
CHECK_RANKS_AND_SIZES_FOR(Pad);
285+
CHECK_RANKS_AND_SIZES_FOR(Reshape);
286+
CHECK_RANKS_AND_SIZES_FOR(Reverse);
287+
CHECK_RANKS_AND_SIZES_FOR(Slice);
288+
CHECK_RANKS_AND_SIZES_FOR(Tile);
289+
CHECK_RANKS_AND_SIZES_FOR(Transpose);
243290
// all type conversion operators:
244-
CHECK_RANKS_FOR(Cast);
245-
CHECK_RANKS_FOR(Rescale);
291+
CHECK_RANKS_AND_SIZES_FOR(Cast);
292+
CHECK_RANKS_AND_SIZES_FOR(Rescale);
246293
// all data nodes operators:
247-
CHECK_RANKS_FOR(Const);
248-
CHECK_RANKS_FOR(Identity);
294+
CHECK_RANKS_AND_SIZES_FOR(Const);
295+
CHECK_RANKS_AND_SIZES_FOR(Identity);
249296

250-
#undef CHECK_RANKS_FOR
297+
#undef CHECK_RANKS_AND_SIZES_FOR
251298
return true;
252299
}
253300

@@ -396,6 +443,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
396443
return true;
397444
}
398445

446+
bool levelCheckListSize(Operation *op) {
447+
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
448+
return levelCheckListSize(op, concat.getInput1().size(), "input1");
449+
}
450+
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
451+
if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
452+
!levelCheckListSize(op, custom.getOutputList().size(),
453+
"output_list")) {
454+
return false;
455+
}
456+
}
457+
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
458+
if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
459+
!levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
460+
return false;
461+
}
462+
}
463+
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
464+
if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
465+
!levelCheckListSize(op, w.getOutput().size(), "outputs")) {
466+
return false;
467+
}
468+
}
469+
return true;
470+
}
471+
399472
// configure profile and level values from pass options profileName and
400473
// levelName
401474
void configLevelAndProfile() {
@@ -449,7 +522,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
449522
return success();
450523
}
451524

452-
if (!levelCheckRanks(op)) {
525+
if (!levelCheckRanksAndSizes(op)) {
453526
return failure();
454527
}
455528

@@ -465,6 +538,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
465538
return failure();
466539
}
467540

541+
// level check MAX_TENSOR_LIST_SIZE
542+
if (!levelCheckListSize(op)) {
543+
return failure();
544+
}
545+
468546
return success();
469547
}
470548

@@ -695,6 +773,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
695773
}
696774

697775
bool TosaValidation::isValidElementType(Type type) {
776+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
777+
type = quantType.getStorageType();
778+
698779
if (isa<FloatType>(type)) {
699780
return type.isF32() || type.isF16() || type.isBF16();
700781
} else if (auto intTy = dyn_cast<IntegerType>(type)) {

0 commit comments

Comments
 (0)