Skip to content

Commit ad9de04

Browse files
tatwaichonglhutton1
authored andcommitted
[mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP (llvm#136194)
Error if checks: verify whether the same length and type between input list, output list, and control-flow blocks. Level_checks: verify whether the nested depth exceeds MAX_NESTING. (cherry-picked from commit 3d47bc9) Change-Id: I4f0dfd8f0e43674d1c6b0f72cc7ec144ab31782a
1 parent af5e12f commit ad9de04

File tree

5 files changed

+738
-79
lines changed

5 files changed

+738
-79
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2632,6 +2632,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
26322632
);
26332633

26342634
let hasCustomAssemblyFormat = 1;
2635+
let hasVerifier = 1;
26352636
}
26362637

26372638
//===----------------------------------------------------------------------===//
@@ -2670,6 +2671,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
26702671
);
26712672

26722673
let hasCustomAssemblyFormat = 1;
2674+
let hasVerifier = 1;
26732675
}
26742676

26752677
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
573573
return success();
574574
}
575575

576+
// Verify whether same type and shape of the given two types.
577+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
578+
StringRef name1, Type type2,
579+
StringRef name2) {
580+
auto shapeType1 = dyn_cast<ShapedType>(type1);
581+
auto shapeType2 = dyn_cast<ShapedType>(type2);
582+
if (!shapeType1 || !shapeType2)
583+
return failure();
584+
585+
auto elemType1 = shapeType1.getElementType();
586+
auto elemType2 = shapeType2.getElementType();
587+
if (elemType1 != elemType2)
588+
return op->emitOpError()
589+
<< "require same element type for " << name1 << " (" << elemType1
590+
<< ") and " << name2 << " (" << elemType2 << ")";
591+
592+
if (failed(verifyCompatibleShape(type1, type2)))
593+
return op->emitOpError()
594+
<< "require same shapes for " << name1 << " (" << type1 << ") and "
595+
<< name2 << " (" << type2 << ")";
596+
597+
return success();
598+
}
599+
600+
// Verify whether same length, type, and shape of the given two tensor lists.
601+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
602+
StringRef name1,
603+
ValueRange list2,
604+
StringRef name2) {
605+
if (list1.size() != list2.size())
606+
return op->emitOpError()
607+
<< "require same number of values in " << name1 << " ("
608+
<< list1.size() << ") and " << name2 << " (" << list2.size() << ")";
609+
610+
for (auto [type1, type2] :
611+
llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
612+
if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
613+
return failure();
614+
}
615+
616+
return success();
617+
}
618+
619+
static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
620+
ShapeAdaptor shapeAdaptor(type);
621+
if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
622+
return success();
623+
624+
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
625+
}
626+
576627
// verify that inType and outType have same element types
577628
template <typename T>
578629
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3601,6 +3652,84 @@ void IfOp::print(OpAsmPrinter &p) {
36013652
p.printOptionalAttrDict((*this)->getAttrs());
36023653
}
36033654

3655+
LogicalResult IfOp::verify() {
3656+
if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3657+
"'then_graph' arguments", getInputList(),
3658+
"'input_list'")
3659+
.failed())
3660+
return failure();
3661+
3662+
if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3663+
"'else_graph' arguments", getInputList(),
3664+
"'input_list'")
3665+
.failed())
3666+
return failure();
3667+
3668+
auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3669+
if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3670+
"'then_graph' results", getOutputList(),
3671+
"'output_list'")
3672+
.failed())
3673+
return failure();
3674+
3675+
auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3676+
if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3677+
"'else_graph' results", getOutputList(),
3678+
"'output_list'")
3679+
.failed())
3680+
return failure();
3681+
3682+
auto condType = getCondition().getType();
3683+
if (errorIfShapeNotSizeOne(*this, condType).failed())
3684+
return emitOpError() << "'condition' must be a size 1 tensor, got "
3685+
<< condType;
3686+
3687+
return success();
3688+
}
3689+
3690+
LogicalResult WhileOp::verify() {
3691+
if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3692+
getOutputList(), "'output_list'")
3693+
.failed())
3694+
return failure();
3695+
3696+
if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3697+
"'cond_graph' arguments", getInputList(),
3698+
"'input_list'")
3699+
.failed())
3700+
return failure();
3701+
3702+
if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3703+
"'body_graph' arguments", getInputList(),
3704+
"'input_list'")
3705+
.failed())
3706+
return failure();
3707+
3708+
auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3709+
if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3710+
"'body_graph' results", getInputList(),
3711+
"'input_list'")
3712+
.failed())
3713+
return failure();
3714+
3715+
// Condition block output must be a single element tensor with a single bool
3716+
// value.
3717+
auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3718+
if (condYield.getInputs().size() != 1)
3719+
return emitOpError() << "require 'cond_graph' only have one result";
3720+
3721+
auto condOutType = condYield.getInputs()[0].getType();
3722+
if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3723+
return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3724+
<< condOutType;
3725+
3726+
if (!getElementTypeOrSelf(condOutType).isInteger(1))
3727+
return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3728+
<< condOutType;
3729+
3730+
return success();
3731+
}
3732+
36043733
LogicalResult ReverseOp::verify() {
36053734
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
36063735
/* outType = */ getOutput().getType())

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,35 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
449449
return true;
450450
}
451451

452+
// Recursively perform a bottom-up search to determine the maximum nesting
453+
// depth, starting from a specific operation and continuing up to the function
454+
// or module scope. Tosa nesting_depth starts at 0 and increments by one each
455+
// time a new nested `region` is encountered.
456+
static void getMaxNestedDepth(Operation *op, int32_t &depth) {
457+
if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
458+
return;
459+
460+
op = op->getParentOp();
461+
if (!op)
462+
return;
463+
464+
depth++;
465+
getMaxNestedDepth(op, depth);
466+
return;
467+
}
468+
469+
bool levelCheckMaxNesting(Operation *op) {
470+
int32_t maxNestedDepth = 0;
471+
getMaxNestedDepth(op, maxNestedDepth);
472+
473+
if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
474+
op->emitOpError() << "failed level check: " << maxNestedDepth
475+
<< " >= MAX_NESTING";
476+
return false;
477+
}
478+
return true;
479+
}
480+
452481
bool levelCheckListSize(Operation *op) {
453482
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454483
return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -751,6 +780,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
751780
return failure();
752781
}
753782

783+
if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
784+
if (!levelCheckMaxNesting(op)) {
785+
return failure();
786+
}
787+
}
788+
754789
return success();
755790
}
756791

0 commit comments

Comments
 (0)