Skip to content

Commit d71d1d2

Browse files
Tai78641tatwaichong
andcommitted
[mlir][tosa] Add several level checks
Add the following types of level check to consolidate the level validity - Complete rank level checks for operations. - Add MAX_LOG2_SIZE level check: The maximum value is 63 when the level is set to "none" and 31 when the level is set to "8K". - Add MAX_TENSOR_LIST_SIZE level check : The maximum value is 256 when the level is set to "none" and 64 when the level is set to "8K". - TOSA 1.0 spec does not allow operations with dynamic shapes, so an error should be raised instead. Co-authored-by: TatWai Chong <[email protected]> Change-Id: I797fafe504219e43950824c04839c7187065fe8e
1 parent bc91acc commit d71d1d2

File tree

2 files changed

+837
-106
lines changed

2 files changed

+837
-106
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 199 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,22 @@ struct TosaLevel {
6161
int32_t MAX_KERNEL = 0;
6262
int32_t MAX_STRIDE = 0;
6363
int32_t MAX_SCALE = 0;
64-
65-
// @todo: MAX_LOG2_SIZE value and checks
64+
int32_t MAX_LOG2_SIZE = 0;
65+
int32_t MAX_NESTING = 0;
66+
int32_t MAX_TENSOR_LIST_SIZE = 0;
6667

6768
bool operator==(const TosaLevel &rhs) {
6869
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
69-
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
70+
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
71+
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
72+
MAX_NESTING == rhs.MAX_NESTING &&
73+
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
7074
}
7175
};
7276

73-
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
74-
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
77+
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
78+
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
79+
63, 256, 256};
7580

7681
//===----------------------------------------------------------------------===//
7782
// TOSA Validation Pass.
@@ -111,133 +116,212 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
111116
constCheckers.emplace_back(checkConstantOperandPad);
112117
}
113118

114-
bool levelCheckKernel(Operation *op, int32_t v,
115-
const std::string &checkDesc) {
119+
bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
116120
if (v > tosaLevel.MAX_KERNEL) {
117121
op->emitOpError() << "failed level check: " << checkDesc;
118122
return false;
119123
}
120124
return true;
121125
}
122126

123-
bool levelCheckStride(Operation *op, int32_t v,
124-
const std::string &checkDesc) {
127+
bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
125128
if (v > tosaLevel.MAX_STRIDE) {
126129
op->emitOpError() << "failed level check: " << checkDesc;
127130
return false;
128131
}
129132
return true;
130133
}
131134

132-
bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
135+
bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
133136
if (v > tosaLevel.MAX_SCALE) {
134137
op->emitOpError() << "failed level check: " << checkDesc;
135138
return false;
136139
}
137140
return true;
138141
}
139142

140-
bool levelCheckRank(Operation *op, const Value &v,
141-
const std::string &checkDesc) {
143+
bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
144+
if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
145+
op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
146+
<< checkDesc;
147+
return false;
148+
}
149+
return true;
150+
}
151+
152+
bool levelCheckRankAndSizes(Operation *op, const Value &v,
153+
const StringRef operandOrResult,
154+
int32_t highest_rank) {
142155
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
143156
if (!type.hasRank()) {
144157
op->emitOpError() << "failed level check: unranked tensor";
145158
return false;
146159
}
147-
if (type.getRank() > tosaLevel.MAX_RANK) {
148-
op->emitOpError() << "failed level check: " << checkDesc;
160+
if (type.getRank() > highest_rank) {
161+
op->emitOpError() << "failed level check: " << operandOrResult
162+
<< " rank(shape) <= MAX_RANK";
163+
return false;
164+
}
165+
166+
auto shape = type.getShape();
167+
for (auto dim : shape) {
168+
if (mlir::ShapedType::isDynamic(dim)) {
169+
op->emitOpError() << "failed level check: " << operandOrResult
170+
<< " shape dimension cannot be dynamic";
171+
return false;
172+
}
173+
}
174+
175+
int64_t element_bits = type.getElementTypeBitWidth();
176+
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
177+
int64_t size = element_bytes * type.getNumElements();
178+
179+
// According to 1.11. Tensor Definitions of Tosa spec, the value of
180+
// tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
181+
// defined in 1.7. Levels.
182+
// For each tensor, the number of tensor elements multiplied by the
183+
// element size in bytes must be representable as a tensor_size_t.
184+
const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
185+
if (size > max_size) {
186+
op->emitOpError()
187+
<< "failed level check: " << operandOrResult
188+
<< " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
149189
return false;
150190
}
151191
}
152192
return true;
153193
}
154194

155195
template <typename T>
156-
bool levelCheckRanksFor(Operation *op) {
157-
if (dyn_cast<T>(op)) {
158-
// level check ranks of all operands and results
159-
for (auto v : op->getOperands()) {
160-
if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
161-
return false;
162-
}
163-
for (auto v : op->getResults()) {
164-
if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
165-
return false;
166-
}
196+
bool levelCheckRanksAndSizesFor(T tosaOp) {
197+
// level check ranks of all operands and results
198+
auto op = tosaOp.getOperation();
199+
for (auto v : op->getOperands()) {
200+
if (!levelCheckRankAndSizes(op, v, "operand", tosaLevel.MAX_RANK))
201+
return false;
202+
}
203+
204+
for (auto v : op->getResults()) {
205+
if (!levelCheckRankAndSizes(op, v, "result", tosaLevel.MAX_RANK))
206+
return false;
167207
}
168208
return true;
169209
}
170210

171-
bool levelCheckRanks(Operation *op) {
172-
#define CHECK_RANKS_FOR(tosaOp) \
173-
if (!levelCheckRanksFor<tosaOp##Op>(op)) \
174-
return false;
211+
template <>
212+
bool levelCheckRanksAndSizesFor(tosa::ArgMaxOp tosaOp) {
213+
auto op = tosaOp.getOperation();
214+
if (!levelCheckRankAndSizes(op, tosaOp.getInput(), "operand",
215+
tosaLevel.MAX_RANK))
216+
return false;
217+
218+
// rank(output) = rank(input) - 1
219+
if (!levelCheckRankAndSizes(op, tosaOp.getOutput(), "result",
220+
tosaLevel.MAX_RANK - 1))
221+
return false;
222+
223+
return true;
224+
}
225+
226+
template <>
227+
bool levelCheckRanksAndSizesFor(tosa::IfOp tosaOp) {
228+
auto op = tosaOp.getOperation();
229+
230+
// Only the condition input has rank limitation.
231+
if (!levelCheckRankAndSizes(op, tosaOp.getCond(), "operand",
232+
tosaLevel.MAX_RANK))
233+
return false;
234+
235+
return true;
236+
}
237+
238+
bool levelCheckRanksAndSizes(Operation *op) {
239+
#define CHECK_RANKS_AND_SIZES_FOR(tosaOp) \
240+
if (isa<tosa::tosaOp##Op>(op)) \
241+
if (!levelCheckRanksAndSizesFor(cast<tosa::tosaOp##Op>(op))) \
242+
return false;
243+
244+
#define CHECK_RANKS_AND_SIZES_SKIP(tosaOp) \
245+
if (isa<tosa::tosaOp##Op>(op)) \
246+
return true;
175247

176248
// tensor operators:
177-
CHECK_RANKS_FOR(ArgMax);
249+
CHECK_RANKS_AND_SIZES_FOR(ArgMax);
178250
// all activation functions:
179-
CHECK_RANKS_FOR(Clamp);
180-
CHECK_RANKS_FOR(Sigmoid);
181-
CHECK_RANKS_FOR(Tanh);
251+
CHECK_RANKS_AND_SIZES_FOR(Clamp);
252+
CHECK_RANKS_AND_SIZES_FOR(Erf);
253+
CHECK_RANKS_AND_SIZES_FOR(Sigmoid);
254+
CHECK_RANKS_AND_SIZES_FOR(Tanh);
182255
// all elementwise binary operators:
183-
CHECK_RANKS_FOR(Add);
184-
CHECK_RANKS_FOR(ArithmeticRightShift);
185-
CHECK_RANKS_FOR(BitwiseAnd);
186-
CHECK_RANKS_FOR(BitwiseOr);
187-
CHECK_RANKS_FOR(BitwiseXor);
188-
CHECK_RANKS_FOR(IntDiv);
189-
CHECK_RANKS_FOR(LogicalAnd);
190-
CHECK_RANKS_FOR(LogicalLeftShift);
191-
CHECK_RANKS_FOR(LogicalRightShift);
192-
CHECK_RANKS_FOR(LogicalOr);
193-
CHECK_RANKS_FOR(LogicalXor);
194-
CHECK_RANKS_FOR(Maximum);
195-
CHECK_RANKS_FOR(Minimum);
196-
CHECK_RANKS_FOR(Mul);
197-
CHECK_RANKS_FOR(Pow);
198-
CHECK_RANKS_FOR(Sub);
199-
CHECK_RANKS_FOR(Table);
256+
CHECK_RANKS_AND_SIZES_FOR(Add);
257+
CHECK_RANKS_AND_SIZES_FOR(ArithmeticRightShift);
258+
CHECK_RANKS_AND_SIZES_FOR(BitwiseAnd);
259+
CHECK_RANKS_AND_SIZES_FOR(BitwiseOr);
260+
CHECK_RANKS_AND_SIZES_FOR(BitwiseXor);
261+
CHECK_RANKS_AND_SIZES_FOR(IntDiv);
262+
CHECK_RANKS_AND_SIZES_FOR(LogicalAnd);
263+
CHECK_RANKS_AND_SIZES_FOR(LogicalLeftShift);
264+
CHECK_RANKS_AND_SIZES_FOR(LogicalRightShift);
265+
CHECK_RANKS_AND_SIZES_FOR(LogicalOr);
266+
CHECK_RANKS_AND_SIZES_FOR(LogicalXor);
267+
CHECK_RANKS_AND_SIZES_FOR(Maximum);
268+
CHECK_RANKS_AND_SIZES_FOR(Minimum);
269+
CHECK_RANKS_AND_SIZES_FOR(Mul);
270+
CHECK_RANKS_AND_SIZES_FOR(Pow);
271+
CHECK_RANKS_AND_SIZES_FOR(Sub);
272+
CHECK_RANKS_AND_SIZES_FOR(Table);
200273
// all elementwise unary operators:
201-
CHECK_RANKS_FOR(Abs);
202-
CHECK_RANKS_FOR(BitwiseNot);
203-
CHECK_RANKS_FOR(Ceil);
204-
CHECK_RANKS_FOR(Clz);
205-
CHECK_RANKS_FOR(Exp);
206-
CHECK_RANKS_FOR(Floor);
207-
CHECK_RANKS_FOR(Log);
208-
CHECK_RANKS_FOR(LogicalNot);
209-
CHECK_RANKS_FOR(Negate);
210-
CHECK_RANKS_FOR(Reciprocal);
211-
CHECK_RANKS_FOR(Rsqrt);
274+
CHECK_RANKS_AND_SIZES_FOR(Abs);
275+
CHECK_RANKS_AND_SIZES_FOR(BitwiseNot);
276+
CHECK_RANKS_AND_SIZES_FOR(Ceil);
277+
CHECK_RANKS_AND_SIZES_FOR(Clz);
278+
CHECK_RANKS_AND_SIZES_FOR(Cos);
279+
CHECK_RANKS_AND_SIZES_FOR(Exp);
280+
CHECK_RANKS_AND_SIZES_FOR(Floor);
281+
CHECK_RANKS_AND_SIZES_FOR(Log);
282+
CHECK_RANKS_AND_SIZES_FOR(LogicalNot);
283+
CHECK_RANKS_AND_SIZES_FOR(Negate);
284+
CHECK_RANKS_AND_SIZES_FOR(Reciprocal);
285+
CHECK_RANKS_AND_SIZES_FOR(Rsqrt);
286+
CHECK_RANKS_AND_SIZES_FOR(Sin);
212287
// all elementwise ternary operators:
213-
CHECK_RANKS_FOR(Select);
288+
CHECK_RANKS_AND_SIZES_FOR(Select);
214289
// all comparison operators:
215-
CHECK_RANKS_FOR(Equal);
216-
CHECK_RANKS_FOR(Greater);
217-
CHECK_RANKS_FOR(GreaterEqual);
290+
CHECK_RANKS_AND_SIZES_FOR(Equal);
291+
CHECK_RANKS_AND_SIZES_FOR(Greater);
292+
CHECK_RANKS_AND_SIZES_FOR(GreaterEqual);
218293
// all reduction operators:
219-
CHECK_RANKS_FOR(ReduceAll);
220-
CHECK_RANKS_FOR(ReduceAny);
221-
CHECK_RANKS_FOR(ReduceMax);
222-
CHECK_RANKS_FOR(ReduceMin);
223-
CHECK_RANKS_FOR(ReduceProduct);
224-
CHECK_RANKS_FOR(ReduceSum);
294+
CHECK_RANKS_AND_SIZES_FOR(ReduceAll);
295+
CHECK_RANKS_AND_SIZES_FOR(ReduceAny);
296+
CHECK_RANKS_AND_SIZES_FOR(ReduceMax);
297+
CHECK_RANKS_AND_SIZES_FOR(ReduceMin);
298+
CHECK_RANKS_AND_SIZES_FOR(ReduceProduct);
299+
CHECK_RANKS_AND_SIZES_FOR(ReduceSum);
225300
// all data layout operators:
226-
CHECK_RANKS_FOR(Concat);
227-
CHECK_RANKS_FOR(Pad);
228-
CHECK_RANKS_FOR(Reshape);
229-
CHECK_RANKS_FOR(Reverse);
230-
CHECK_RANKS_FOR(Slice);
231-
CHECK_RANKS_FOR(Tile);
232-
CHECK_RANKS_FOR(Transpose);
301+
CHECK_RANKS_AND_SIZES_FOR(Concat);
302+
CHECK_RANKS_AND_SIZES_FOR(Pad);
303+
CHECK_RANKS_AND_SIZES_FOR(Reshape);
304+
CHECK_RANKS_AND_SIZES_FOR(Reverse);
305+
CHECK_RANKS_AND_SIZES_FOR(Slice);
306+
CHECK_RANKS_AND_SIZES_FOR(Tile);
307+
CHECK_RANKS_AND_SIZES_FOR(Transpose);
233308
// all type conversion operators:
234-
CHECK_RANKS_FOR(Cast);
235-
CHECK_RANKS_FOR(Rescale);
309+
CHECK_RANKS_AND_SIZES_FOR(Cast);
310+
CHECK_RANKS_AND_SIZES_FOR(Rescale);
311+
// control flow operators:
312+
CHECK_RANKS_AND_SIZES_FOR(If);
236313
// all data nodes operators:
237-
CHECK_RANKS_FOR(Const);
238-
CHECK_RANKS_FOR(Identity);
314+
CHECK_RANKS_AND_SIZES_FOR(Const);
315+
CHECK_RANKS_AND_SIZES_FOR(Identity);
239316

240-
#undef CHECK_RANKS_FOR
317+
// The following operators do not have level rank and size constraint.
318+
CHECK_RANKS_AND_SIZES_SKIP(Resize);
319+
CHECK_RANKS_AND_SIZES_SKIP(Yield);
320+
CHECK_RANKS_AND_SIZES_SKIP(Custom);
321+
CHECK_RANKS_AND_SIZES_SKIP(While);
322+
323+
#undef CHECK_RANKS_AND_SIZES_FOR
324+
#undef CHECK_RANKS_AND_SIZES_SKIP
241325
return true;
242326
}
243327

@@ -386,6 +470,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386470
return true;
387471
}
388472

473+
bool levelCheckListSize(Operation *op) {
474+
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
475+
return levelCheckListSize(op, concat.getInput1().size(), "input1");
476+
}
477+
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
478+
if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
479+
!levelCheckListSize(op, custom.getOutputList().size(),
480+
"output_list")) {
481+
return false;
482+
}
483+
}
484+
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
485+
if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
486+
!levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
487+
return false;
488+
}
489+
}
490+
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
491+
if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
492+
!levelCheckListSize(op, w.getOutput().size(), "outputs")) {
493+
return false;
494+
}
495+
}
496+
return true;
497+
}
498+
389499
// configure profile and level values from pass options profileName and
390500
// levelName
391501
void configLevelAndProfile() {
@@ -439,7 +549,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439549
return success();
440550
}
441551

442-
if (!levelCheckRanks(op)) {
552+
if (!levelCheckRanksAndSizes(op)) {
443553
return failure();
444554
}
445555

@@ -455,6 +565,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
455565
return failure();
456566
}
457567

568+
// level check MAX_TENSOR_LIST_SIZE
569+
if (!levelCheckListSize(op)) {
570+
return failure();
571+
}
572+
458573
return success();
459574
}
460575

0 commit comments

Comments
 (0)