@@ -949,6 +949,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
949
949
return success ();
950
950
}
951
951
952
+ LogicalResult tosa::ConcatOp::verify () {
953
+ // check that each input has same element type as output
954
+ auto outType = getOutput ().getType ();
955
+ const Operation::operand_range inputList = getInput1 ();
956
+
957
+ // Check there is at least one input
958
+ if (inputList.empty ())
959
+ return emitOpError (" expect at least one input" );
960
+
961
+ if (!llvm::all_of (inputList, [&](auto input) {
962
+ return succeeded (verifySameElementTypes (
963
+ *this , /* inType = */ input.getType (), outType));
964
+ })) {
965
+ return failure ();
966
+ }
967
+
968
+ const int32_t axis = getAxis ();
969
+ ShapeAdaptor firstRankedInputShape = nullptr ;
970
+ for (const auto &input : inputList) {
971
+ const Type inputType = input.getType ();
972
+ ShapeAdaptor currShape (inputType);
973
+ if (currShape.hasRank ()) {
974
+ firstRankedInputShape = currShape;
975
+ // Check axis is in expected range
976
+ if (axis < 0 || axis >= firstRankedInputShape.getRank ())
977
+ return emitOpError (" expect axis to be within range 0 < axis < "
978
+ " rank(input1[firstRankedTensorIdx]), got " )
979
+ << axis;
980
+ break ;
981
+ }
982
+ }
983
+
984
+ const auto allOperandsHasRank = [](const Value input) {
985
+ return ShapeAdaptor (input.getType ()).hasRank ();
986
+ };
987
+ if (llvm::all_of (inputList, allOperandsHasRank)) {
988
+ const int64_t firstInputRank = firstRankedInputShape.getRank ();
989
+
990
+ for (const auto &[index, input] : llvm::enumerate (inputList.drop_front ())) {
991
+ const ShapeAdaptor inputShape (input.getType ());
992
+ const int64_t inputRank = inputShape.getRank ();
993
+ const size_t operandNum = index + 1 ;
994
+
995
+ // Check that each operand has the same rank
996
+ if (inputRank != firstInputRank)
997
+ return emitOpError (
998
+ " expect all operands to have the same rank, but got " )
999
+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
1000
+ << operandNum;
1001
+
1002
+ // Check non-axis dims match
1003
+ for (int i = 0 ; i < inputRank; i++) {
1004
+ const int64_t inputDim = inputShape.getDimSize (i);
1005
+ const int64_t firstInputDim = firstRankedInputShape.getDimSize (i);
1006
+ if (i == axis || firstRankedInputShape.isDynamicDim (i) ||
1007
+ inputShape.isDynamicDim (i))
1008
+ continue ;
1009
+ if (inputDim != firstInputDim)
1010
+ return emitOpError (" expect all operand shapes to have the same sizes "
1011
+ " on non-axis dimensions, but got " )
1012
+ << inputDim << " vs " << firstInputDim << " at index " << i
1013
+ << " on operands 0 and " << operandNum;
1014
+ }
1015
+ }
1016
+ }
1017
+
1018
+ return success ();
1019
+ }
1020
+
952
1021
LogicalResult tosa::EqualOp::inferReturnTypeComponents (
953
1022
MLIRContext *context, ::std::optional<Location> location,
954
1023
ValueShapeRange operands, DictionaryAttr attributes,
@@ -998,6 +1067,51 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
998
1067
return success ();
999
1068
}
1000
1069
1070
+ LogicalResult MatMulOp::verify () {
1071
+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
1072
+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
1073
+
1074
+ // Must be shaped tensor types
1075
+ if (!aType)
1076
+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
1077
+
1078
+ if (!bType)
1079
+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
1080
+
1081
+ auto aElementType = aType.getElementType ();
1082
+ auto bElementType = bType.getElementType ();
1083
+
1084
+ auto aQuantizedEType =
1085
+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1086
+ auto bQuantizedEType =
1087
+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1088
+
1089
+ if (aQuantizedEType || bQuantizedEType) {
1090
+ if (!aQuantizedEType || !bQuantizedEType) {
1091
+ emitOpError (
1092
+ " expect operands to be both quantized or both not quantized, got " )
1093
+ << aElementType << " and " << bElementType;
1094
+ }
1095
+ // both a and b have quantized element types
1096
+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1097
+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1098
+ if (aQuantWidth != bQuantWidth) {
1099
+ emitOpError (" expect quantized operands to have same widths, got " )
1100
+ << aQuantWidth << " and " << bQuantWidth;
1101
+ }
1102
+
1103
+ return success ();
1104
+ }
1105
+
1106
+ // non-quantized element types
1107
+ if (aElementType != bElementType) {
1108
+ emitOpError (" expect same element type for inputs a and b, got " )
1109
+ << aElementType << " and " << bElementType;
1110
+ }
1111
+
1112
+ return success ();
1113
+ }
1114
+
1001
1115
LogicalResult tosa::PadOp::inferReturnTypeComponents (
1002
1116
MLIRContext *context, ::std::optional<Location> location,
1003
1117
PadOp::Adaptor adaptor,
@@ -1046,6 +1160,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
1046
1160
}
1047
1161
1048
1162
LogicalResult tosa::PadOp::verify () {
1163
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1164
+ /* outType = */ getOutput ().getType ())
1165
+ .failed ()) {
1166
+ return failure ();
1167
+ }
1168
+
1169
+ if (auto padConst = getPadConst ()) {
1170
+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1171
+ /* outType = */ getOutput ().getType ())
1172
+ .failed ()) {
1173
+ return failure ();
1174
+ }
1175
+ }
1176
+
1049
1177
RankedTensorType inputType = getInput1 ().getType ();
1050
1178
RankedTensorType outputType = getOutput ().getType ();
1051
1179
auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1119,21 +1247,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1119
1247
}
1120
1248
1121
1249
LogicalResult tosa::SliceOp::verify() {
  // Input and output must carry the same element type.
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();

  const auto rankedInput =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  // Unranked inputs leave nothing further to verify.
  if (!rankedInput)
    return success();

  const int64_t inputRank = rankedInput.getRank();

  // Both the start and size shape operands must have one entry per input dim.
  if (llvm::cast<tosa::shapeType>(getStart().getType()).getRank() != inputRank)
    return emitOpError("length of start is not equal to rank of input shape");

  if (llvm::cast<tosa::shapeType>(getSize().getType()).getRank() != inputRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1338,6 +1468,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
1338
1468
}
1339
1469
1340
1470
LogicalResult tosa::TileOp::verify () {
1471
+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1472
+ /* outType = */ getOutput ().getType ())
1473
+ .failed ()) {
1474
+ return failure ();
1475
+ }
1341
1476
ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
1342
1477
ShapedType outputType = llvm::cast<ShapedType>(getType ());
1343
1478
@@ -1419,6 +1554,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1419
1554
}
1420
1555
1421
1556
llvm::LogicalResult tosa::ReshapeOp::verify () {
1557
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1558
+ /* outType = */ getOutput ().getType ())
1559
+ .failed ()) {
1560
+ return failure ();
1561
+ }
1422
1562
TensorType inputType = getInput1 ().getType ();
1423
1563
RankedTensorType outputType = getType ();
1424
1564
@@ -1606,6 +1746,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1606
1746
}
1607
1747
1608
1748
LogicalResult tosa::TransposeOp::verify () {
1749
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1750
+ /* outType = */ getOutput ().getType ())
1751
+ .failed ()) {
1752
+ return failure ();
1753
+ }
1609
1754
TensorType inputType = getInput1 ().getType ();
1610
1755
TensorType outputType = getOutput ().getType ();
1611
1756
const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1706,6 +1851,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1706
1851
return success ();
1707
1852
}
1708
1853
1854
+ LogicalResult tosa::GatherOp::verify () {
1855
+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1856
+ /* outType = */ getOutput ().getType ());
1857
+ }
1858
+
1709
1859
LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
1710
1860
MLIRContext *context, ::std::optional<Location> location,
1711
1861
ResizeOp::Adaptor adaptor,
@@ -1867,6 +2017,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1867
2017
return success ();
1868
2018
}
1869
2019
2020
+ LogicalResult tosa::ScatterOp::verify () {
2021
+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
2022
+ /* outType = */ getValuesOut ().getType ())
2023
+ .failed () ||
2024
+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
2025
+ /* outType = */ getValuesOut ().getType ())
2026
+ .failed ()) {
2027
+ return failure ();
2028
+ }
2029
+ return success ();
2030
+ }
2031
+
1870
2032
static LogicalResult ReduceInferReturnTypes (
1871
2033
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1872
2034
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2322,6 +2484,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2322
2484
inferredReturnShapes);
2323
2485
}
2324
2486
2487
+ LogicalResult MaxPool2dOp::verify () {
2488
+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2489
+ /* outType = */ getOutput ().getType ());
2490
+ }
2491
+
2325
2492
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
2326
2493
MLIRContext *context, ::std::optional<Location> location,
2327
2494
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2622,6 +2789,10 @@ void IfOp::print(OpAsmPrinter &p) {
2622
2789
}
2623
2790
2624
2791
LogicalResult ReverseOp::verify () {
2792
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2793
+ /* outType = */ getOutput ().getType ())
2794
+ .failed ())
2795
+ return failure ();
2625
2796
TensorType inputType = getInput1 ().getType ();
2626
2797
TensorType outputType = getOutput ().getType ();
2627
2798
int32_t reverseAxis = getAxis ();
@@ -2650,6 +2821,31 @@ LogicalResult ReverseOp::verify() {
2650
2821
return success ();
2651
2822
}
2652
2823
2824
+ LogicalResult tosa::SelectOp::verify () {
2825
+ // verify input2 and input3 have same element type as output
2826
+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2827
+ /* outType = */ getOutput ().getType ())
2828
+ .failed () ||
2829
+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2830
+ /* outType = */ getOutput ().getType ())
2831
+ .failed ()) {
2832
+ return failure ();
2833
+ }
2834
+ // verify input1 has element type of bool
2835
+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2836
+ if (!predicateType) {
2837
+ emitOpError (" expect shaped tensor for input1, got " )
2838
+ << getInput1 ().getType ();
2839
+ }
2840
+ auto predicateElementType = predicateType.getElementType ();
2841
+ if (!predicateElementType.isInteger (1 )) {
2842
+ emitOpError (" expect element type of bool for input1, got " )
2843
+ << predicateElementType;
2844
+ }
2845
+
2846
+ return success ();
2847
+ }
2848
+
2653
2849
// parse and print of WhileOp refer to the implementation of SCF dialect.
2654
2850
ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
2655
2851
SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments