Skip to content

Commit 20ae283

Browse files
authored
[mlir][tosa] Change the shift of mul to be required (#125297)
Change the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor whose shape is [1] (ie, tensor containing a single element) Signed-off-by: Tai Ly <[email protected]>
1 parent 070f84e commit 20ae283

File tree

12 files changed

+135
-78
lines changed

12 files changed

+135
-78
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
105105
Tosa_Tensor4D:$input,
106106
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
107107
Tosa_Tensor1D:$bias,
108-
Optional<Tosa_ZeroPointTensor>:$input_zp,
109-
Optional<Tosa_ZeroPointTensor>:$weight_zp,
108+
Optional<Tosa_ScalarTensor>:$input_zp,
109+
Optional<Tosa_ScalarTensor>:$weight_zp,
110110
Tosa_IntArrayAttr4:$pad,
111111
Tosa_IntArrayAttr2:$stride,
112112
Tosa_IntArrayAttr2:$dilation,
@@ -136,8 +136,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
136136
Tosa_Tensor5D:$input,
137137
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
138138
Tosa_Tensor1D:$bias,
139-
Optional<Tosa_ZeroPointTensor>:$input_zp,
140-
Optional<Tosa_ZeroPointTensor>:$weight_zp,
139+
Optional<Tosa_ScalarTensor>:$input_zp,
140+
Optional<Tosa_ScalarTensor>:$weight_zp,
141141
Tosa_IntArrayAttr6:$pad,
142142
Tosa_IntArrayAttr3:$stride,
143143
Tosa_IntArrayAttr3:$dilation,
@@ -168,8 +168,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
168168
Tosa_Tensor4D:$input,
169169
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
170170
Tosa_Tensor1D:$bias,
171-
Optional<Tosa_ZeroPointTensor>:$input_zp,
172-
Optional<Tosa_ZeroPointTensor>:$weight_zp,
171+
Optional<Tosa_ScalarTensor>:$input_zp,
172+
Optional<Tosa_ScalarTensor>:$weight_zp,
173173
Tosa_IntArrayAttr4:$pad,
174174
Tosa_IntArrayAttr2:$stride,
175175
Tosa_IntArrayAttr2:$dilation,
@@ -356,8 +356,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
356356
Tosa_Tensor4D:$input,
357357
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
358358
Tosa_Tensor1D:$bias,
359-
Optional<Tosa_ZeroPointTensor>:$input_zp,
360-
Optional<Tosa_ZeroPointTensor>:$weight_zp,
359+
Optional<Tosa_ScalarTensor>:$input_zp,
360+
Optional<Tosa_ScalarTensor>:$weight_zp,
361361
Tosa_IntArrayAttr4:$out_pad,
362362
Tosa_IntArrayAttr2:$stride,
363363
Tosa_IntArrayAttr4:$out_shape,
@@ -817,7 +817,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
817817
let arguments = (ins
818818
Tosa_Tensor:$input1,
819819
Tosa_Tensor:$input2,
820-
Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
820+
// Apply right shift on i32_t input data only
821+
Tosa_ScalarInt8Tensor:$shift
821822
);
822823

823824
let results = (outs
@@ -1590,7 +1591,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
15901591
let arguments = (ins
15911592
Tosa_RankedTensor:$input1,
15921593
Tosa_Shape:$padding,
1593-
Optional<Tosa_ScalarTensor>:$pad_const,
1594+
Optional<Tosa_Rank0Tensor>:$pad_const,
15941595
OptionalAttr<I32Attr>:$input_zp
15951596
);
15961597

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
9393
IsRankedTensorTypePred,
9494
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
9595

96+
def AllDimensionsAreSizeOne : And<[
97+
IsRankedTensorTypePred,
98+
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
99+
96100
class TosaTensorOf<
97101
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
98102
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -109,6 +113,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
109113
[HasAnyRankOfPred<ranks>],
110114
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
111115

116+
class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
117+
: TosaRankedTensorOf<allowedTypes,
118+
[HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
119+
"tosa-conformant scalar tensor">;
120+
112121
//===----------------------------------------------------------------------===//
113122
// Tensor types
114123
//===----------------------------------------------------------------------===//
@@ -136,8 +145,10 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
136145
// Tensor types with constrained ranks.
137146
//===----------------------------------------------------------------------===//
138147

139-
// Rank-0 (scalar) tensor
140-
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
148+
def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
149+
150+
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
151+
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
141152

142153
// We include unranked tensors as a supported type for all possible tosa
143154
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
@@ -296,9 +307,4 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
296307
def Rank2TosaShape : TosaShapeOfRank<2>;
297308
def Rank4TosaShape : TosaShapeOfRank<4>;
298309

299-
// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this
300-
// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the
301-
// following def can be removed.
302-
def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;
303-
304310
#endif // TOSA_TYPES_BASE

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,22 +92,27 @@ static Value createLinalgBodyCalculationForElementwiseOp(
9292
// tosa::MulOp
9393
if (isa<tosa::MulOp>(op)) {
9494
auto shift_val = cast<tosa::MulOp>(op).getShift();
95+
ElementsAttr shift_elem;
96+
if (!shift_val.getImpl() ||
97+
!matchPattern(shift_val, m_Constant(&shift_elem))) {
98+
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
99+
}
100+
101+
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
95102

96103
if (isa<FloatType>(elementTy)) {
104+
if (shift != 0) {
105+
(void)rewriter.notifyMatchFailure(op,
106+
"Cannot have shift value for float");
107+
return nullptr;
108+
}
97109
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
98110
}
99111

100112
if (isa<IntegerType>(elementTy)) {
101-
int32_t shift = 0;
102-
ElementsAttr shift_elem;
103-
if (shift_val.getImpl() &&
104-
matchPattern(shift_val, m_Constant(&shift_elem))) {
105-
// Explicit shift is set.
106-
shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
107-
}
108-
109113
Value a = args[0];
110114
Value b = args[1];
115+
111116
if (shift > 0) {
112117
auto shiftConst =
113118
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,16 +1130,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
11301130
ValueShapeRange operands, DictionaryAttr attributes,
11311131
OpaqueProperties properties, RegionRange regions,
11321132
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1133-
LogicalResult status = success();
1133+
// mul op's output shape only depend on input1 and input2, not on shift
1134+
ValueShapeRange twoInputs = operands.drop_back();
11341135
llvm::SmallVector<int64_t> outShape;
1135-
if (operands.size() == 2) {
1136-
status = resolveBroadcastShape(operands, outShape);
1137-
} else {
1138-
// mul op's output shape only depend on input1 and input2, not on shift
1139-
ValueShapeRange two_inputs = operands.drop_back();
1140-
status = resolveBroadcastShape(two_inputs, outShape);
1141-
}
1142-
if (status.failed()) {
1136+
if (resolveBroadcastShape(twoInputs, outShape).failed()) {
11431137
inferredReturnShapes.push_back(ShapedTypeComponents());
11441138
} else {
11451139
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -1174,6 +1168,15 @@ LogicalResult tosa::MulOp::verify() {
11741168
return emitOpError(
11751169
"requires the same element type for all operands and results");
11761170
}
1171+
1172+
// verify shift has value 0 for non-integer types
1173+
ElementsAttr shift_elem;
1174+
if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1175+
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1176+
if (shift != 0) {
1177+
return emitOpError() << "require shift to be 0 for float type";
1178+
}
1179+
}
11771180
}
11781181

11791182
// Verify the op has same ranks for all main operands (excludes extra operands

mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
287287

288288
for (Value operand : op->getOperands()) {
289289
// If this is a problem in future, think about alternatives to recursion.
290-
if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
291-
operand == op->getOperand(2)) {
290+
if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
292291
// do not recurse into MulOp's shift operand
293292
continue;
294293
}
@@ -332,8 +331,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
332331
for (Value v : op->getOperands()) {
333332
if (valuesMap.contains(v)) {
334333
operands.push_back(valuesMap.at(v));
335-
} else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
336-
v == op->getOperand(2)) {
334+
} else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
337335
// special case for MulOp's shift operand
338336
operands.push_back(v);
339337
} else {

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
472472

473473
// CHECK: linalg.generic
474474
// CHECK: arith.mulf
475-
%4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
475+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
476+
%4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>
476477

477478
// CHECK: linalg.generic
478479
// CHECK: arith.negf
@@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
618619
// CHECK: arith.extsi
619620
// CHECK: arith.extsi
620621
// CHECK: arith.muli
621-
%0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
622+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
623+
%0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>
622624

623625
return
624626
}

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,9 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
322322
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
323323
// CHECK: return %arg0
324324
// CHECK-NOT: tosa.mul
325+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
325326
%ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
326-
%1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
327+
%1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
327328
return %1 : tensor<2x3xf32>
328329
}
329330

@@ -334,7 +335,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
334335
// CHECK: return %arg0
335336
// CHECK-NOT: tosa.mul
336337
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
337-
%1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
338+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
339+
%1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
338340
return %1 : tensor<2x3xf32>
339341
}
340342

@@ -370,11 +372,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
370372
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
371373
// CHECK-NOT: tosa.mul
372374
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
373-
%1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
375+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
376+
%1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
374377

375378
// CHECK-NOT: tosa.mul
376379
// CHECK: return %[[ZERO]], %[[ZERO]]
377-
%2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
380+
%2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
378381
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
379382
}
380383

@@ -974,7 +977,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
974977
// CHECK: tosa.mul
975978
%0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
976979
%1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
977-
%2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>)-> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
980+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
981+
%2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
978982
return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
979983
}
980984

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
238238
func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
239239
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
240240
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
241-
%mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
241+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
242+
%mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
242243
// CHECK: return %[[ZERO]]
243244
return %mul : tensor<f32>
244245
}
@@ -249,7 +250,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
249250
func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
250251
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
251252
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
252-
%mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
253+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
254+
%mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
253255
// CHECK: return %[[ZERO]]
254256
return %mul : tensor<f32>
255257
}
@@ -283,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
283285
// CHECK-LABEL: @fold_mul_one_rhs_f32
284286
func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
285287
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
286-
%mul = tosa.mul %arg0, %one : (tensor<f32>, tensor<f32>) -> tensor<f32>
288+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
289+
%mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
287290
// CHECK: return %arg0
288291
return %mul : tensor<f32>
289292
}
@@ -293,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
293296
// CHECK-LABEL: @fold_mul_one_lhs_f32
294297
func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
295298
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
296-
%mul = tosa.mul %one, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
299+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
300+
%mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
297301
// CHECK: return %arg0
298302
return %mul : tensor<f32>
299303
}
@@ -339,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
339343
func.func @fold_mul_splat_f32() -> tensor<10xf32> {
340344
%one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
341345
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
342-
%mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
346+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
347+
%mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
343348
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
344349
// CHECK: return %[[THREE]]
345350
return %mul : tensor<10xf32>

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -768,26 +768,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
768768

769769
// CHECK-LABEL: test_mul_type_mismatch
770770
func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
771+
%shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
771772
// expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
772-
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32>
773+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
773774
return %0 : tensor<13x21x3xf32>
774775
}
775776

776777
// -----
777778

778779
// CHECK-LABEL: test_mul_invalid_shift
779-
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
780-
%shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
781-
// expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
782-
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<f32>) -> tensor<13x21x3xi32>
783-
return %0 : tensor<13x21x3xi32>
780+
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
781+
%shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
782+
// expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
783+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
784+
return %0 : tensor<13x21x3xf32>
784785
}
785786

786787
// -----
787788

788789
// CHECK-LABEL: test_mul_missing_shift
789790
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
790-
// this is ok because mul's shift operand is optional for now
791+
// expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
791792
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
792793
return %0 : tensor<13x21x3xi32>
793794
}
@@ -1099,3 +1100,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
10991100
%0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
11001101
return %0 : tensor<1x13x21x3xf32>
11011102
}
1103+
1104+
// -----
1105+
// CHECK-LABEL: test_mul_non_scalar_shift_2d
1106+
func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
1107+
%shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
1108+
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
1109+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
1110+
return %0 : tensor<13x21x3xf32>
1111+
}
1112+
1113+
// -----
1114+
// CHECK-LABEL: test_mul_non_scalar_shift_1d
1115+
func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
1116+
%shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
1117+
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
1118+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
1119+
return %0 : tensor<13x21x3xf32>
1120+
}
1121+
1122+
// -----
1123+
// CHECK-LABEL: test_mul_non_broadcast
1124+
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
1125+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1126+
// expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
1127+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
1128+
return %0 : tensor<13x21x3xf32>
1129+
}

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tenso
355355
// -----
356356
// CHECK-LABEL: mul
357357
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
358-
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
358+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
359+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
359360
return %0 : tensor<13x21x3xf32>
360361
}
361362

0 commit comments

Comments
 (0)