
Commit 0186960

Jerry-Ge, Tai78641, and lhutton1 committed
[mlir][tosa] Add more verifiers for the following operators
For ConcatOp, this commit also enhances the verifier by checking four additional conditions:
- The input list is not empty
- The axis value is within the valid range of the input rank
- All inputs have the same rank
- All non-concatenation-axis dimensions have the same size

For MatMulOp:
- Checked that inputs a and b have shaped tensor types and matching element types (including the storage widths of quantized element types)

For the following operators, added the verifySameElementTypes check:
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp

Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Signed-off-by: Jerry Ge <[email protected]>
1 parent d1c1ab1 commit 0186960
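To make the new ConcatOp checks concrete, here is a hypothetical piece of IR (not taken from this commit's tests, and assuming the textual form tosa.concat uses around this revision, with axis as an i32 attribute) that the added verifier now rejects, because a non-axis dimension disagrees between the two inputs:

  func.func @concat_non_axis_dim_mismatch(%a: tensor<2x3xf32>, %b: tensor<2x4xf32>) -> tensor<4x3xf32> {
    // Concatenation is along axis 0, but dimension 1 differs (3 vs 4), so
    // verification now fails instead of deferring the problem to later passes.
    %0 = tosa.concat %a, %b {axis = 0 : i32} : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<4x3xf32>
    return %0 : tensor<4x3xf32>
  }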

3 files changed, +245 -13 lines changed


mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 8 additions & 0 deletions
@@ -315,6 +315,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   ];
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -349,6 +350,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   ];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1479,6 +1481,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
 
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1854,6 +1857,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -2110,6 +2114,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2143,6 +2149,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 200 additions & 4 deletions
@@ -949,6 +949,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // check that each input has same element type as output
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  if (!llvm::all_of(inputList, [&](auto input) {
+        return succeeded(verifySameElementTypes(
+            *this, /* inType = */ input.getType(), outType));
+      })) {
+    return failure();
+  }
+
+  const int32_t axis = getAxis();
+  ShapeAdaptor firstRankedInputShape = nullptr;
+  for (const auto &input : inputList) {
+    const Type inputType = input.getType();
+    ShapeAdaptor currShape(inputType);
+    if (currShape.hasRank()) {
+      firstRankedInputShape = currShape;
+      // Check axis is in expected range
+      if (axis < 0 || axis >= firstRankedInputShape.getRank())
+        return emitOpError("expect axis to be within range 0 < axis < "
+                           "rank(input1[firstRankedTensorIdx]), got ")
+               << axis;
+      break;
+    }
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstRankedInputShape.getRank();
+
+    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank
+      if (inputRank != firstInputRank)
+        return emitOpError(
+                   "expect all operands to have the same rank, but got ")
+               << firstInputRank << " vs " << inputRank << " on operands 0 and "
+               << operandNum;
+
+      // Check non-axis dims match
+      for (int i = 0; i < inputRank; i++) {
+        const int64_t inputDim = inputShape.getDimSize(i);
+        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
+        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
+            inputShape.isDynamicDim(i))
+          continue;
+        if (inputDim != firstInputDim)
+          return emitOpError("expect all operand shapes to have the same sizes "
+                             "on non-axis dimensions, but got ")
+                 << inputDim << " vs " << firstInputDim << " at index " << i
+                 << " on operands 0 and " << operandNum;
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
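Note that the non-axis dimension comparison above only fires when both sizes are static; dynamic dimensions and unranked operands are skipped. A hypothetical snippet (same syntax assumptions as before) that still verifies:

  func.func @concat_dynamic_dim_ok(%a: tensor<?x3xf32>, %b: tensor<2x3xf32>) -> tensor<?x3xf32> {
    // Accepted: the dynamic leading dimension of %a is never compared against
    // the static 2 of %b, and dimension 1 matches on both operands.
    %0 = tosa.concat %a, %b {axis = 0 : i32} : (tensor<?x3xf32>, tensor<2x3xf32>) -> tensor<?x3xf32>
    return %0 : tensor<?x3xf32>
  }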
@@ -998,6 +1067,51 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+
+  // Must be shaped tensor types
+  if (!aType)
+    return emitOpError("expect a shaped tensor for input a, got ")
+           << getA().getType();
+
+  if (!bType)
+    return emitOpError("expect a shaped tensor for input b, got ")
+           << getB().getType();
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType) {
+      return emitOpError("expect operands to be both quantized or both not "
+                         "quantized, got ")
+             << aElementType << " and " << bElementType;
+    }
+    // both a and b have quantized element types
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth) {
+      return emitOpError("expect quantized operands to have same widths, got ")
+             << aQuantWidth << " and " << bQuantWidth;
+    }
+
+    return success();
+  }
+
+  // non-quantized element types
+  if (aElementType != bElementType) {
+    return emitOpError("expect same element type for inputs a and b, got ")
+           << aElementType << " and " << bElementType;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
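For MatMulOp, the new element-type check means IR like the following hypothetical snippet is rejected. This is only a sketch: it assumes the two-operand tosa.matmul form without explicit zero-point operands, which may not match every revision of the dialect:

  func.func @matmul_mixed_element_types(%a: tensor<1x3x4xf32>, %b: tensor<1x4x5xf16>) -> tensor<1x3x5xf32> {
    // Rejected: input a is f32 while input b is f16.
    %0 = tosa.matmul %a, %b : (tensor<1x3x4xf32>, tensor<1x4x5xf16>) -> tensor<1x3x5xf32>
    return %0 : tensor<1x3x5xf32>
  }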
@@ -1046,6 +1160,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::PadOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+
+  if (auto padConst = getPadConst()) {
+    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+                               /* outType = */ getOutput().getType())
+            .failed()) {
+      return failure();
+    }
+  }
+
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1119,21 +1247,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::SliceOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   if (!inputType)
     return success();
 
   auto startShapeRank =
       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
   if (inputType.getRank() != startShapeRank)
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
+    return emitOpError("length of start is not equal to rank of input shape");
 
   auto sizeShapeRank =
       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
   if (inputType.getRank() != sizeShapeRank)
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+    return emitOpError("length of size is not equal to rank of input shape");
 
   return success();
 }
@@ -1338,6 +1468,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
@@ -1419,6 +1554,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
 }
 
 llvm::LogicalResult tosa::ReshapeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
@@ -1606,6 +1746,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TransposeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1706,6 +1851,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::GatherOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
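The GatherOp verifier only compares the element types of the gathered values and the output; shapes are still handled by shape inference. A hypothetical snippet it rejects (assuming the default values/indices textual form):

  func.func @gather_element_type_mismatch(%values: tensor<2x6x3xf32>, %indices: tensor<2x4xi32>) -> tensor<2x4x3xi32> {
    // Rejected: the gathered values are f32 but the output element type is i32.
    %0 = tosa.gather %values, %indices : (tensor<2x6x3xf32>, tensor<2x4xi32>) -> tensor<2x4x3xi32>
    return %0 : tensor<2x4x3xi32>
  }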
@@ -1867,6 +2017,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ScatterOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed()) {
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2322,6 +2484,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                              inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
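MaxPool2dOp likewise only gains an element-type check here. A hypothetical rejected snippet, assuming kernel/stride/pad are written as dense i64 array attributes as in recent versions of the dialect:

  func.func @max_pool2d_element_type_mismatch(%input: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xi8> {
    // Rejected: the input element type f32 does not match the output element type i8.
    %0 = tosa.max_pool2d %input {kernel = array<i64: 1, 1>, stride = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xi8>
    return %0 : tensor<1x32x32x8xi8>
  }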
@@ -2622,6 +2789,10 @@ void IfOp::print(OpAsmPrinter &p) {
 }
 
 LogicalResult ReverseOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
@@ -2650,6 +2821,31 @@ LogicalResult ReverseOp::verify() {
   return success();
 }
 
+LogicalResult tosa::SelectOp::verify() {
+  // verify input2 and input3 have same element type as output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // verify input1 has element type of bool
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType) {
+    return emitOpError("expect shaped tensor for input1, got ")
+           << getInput1().getType();
+  }
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1)) {
+    return emitOpError("expect element type of bool for input1, got ")
+           << predicateElementType;
+  }
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
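Finally, SelectOp now requires the predicate (input1) to be a tensor of i1, in addition to input2/input3 matching the output element type. A hypothetical snippet the verifier rejects, assuming the assembly format declared in the .td change above:

  func.func @select_non_bool_predicate(%p: tensor<2xi32>, %a: tensor<2xf32>, %b: tensor<2xf32>) -> tensor<2xf32> {
    // Rejected: input1 must have element type i1, not i32.
    %0 = tosa.select %p, %a, %b : (tensor<2xi32>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
    return %0 : tensor<2xf32>
  }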
