Commit f5749e7

[mlir][tosa] Remove out_shape from transpose_conv2d (#129133)
1 parent 7446601 commit f5749e7
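For orientation: this change removes the out_shape attribute from tosa.transpose_conv2d, so the output shape is carried solely by the op's result type (and by shape inference). A minimal, hypothetical sketch of the op after this commit — function and value names are illustrative, shapes mirror the updated tosa-infer-shapes.mlir tests below:

  func.func @transpose_conv2d_example(%input: tensor<2x16x14x3xf32>, %weight: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> tensor<2x18x19x5xf32> {
    // Only out_pad, stride and acc_type remain; there is no out_shape attribute.
    %0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x18x19x5xf32>
    return %0 : tensor<2x18x19x5xf32>
  }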

6 files changed: 25 additions & 43 deletions


mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 1 addition & 3 deletions
@@ -148,13 +148,11 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
        "::mlir::Value":$weight, "mlir::Value":$bias,
        "::mlir::DenseI64ArrayAttr":$outpad,
        "::mlir::DenseI64ArrayAttr":$stride,
-       "::mlir::DenseI64ArrayAttr":$outputShape,
        "::mlir::TypeAttr":$acc_type),
   [{
     buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
                                   input, weight, bias,
-                                  outpad, stride,
-                                  outputShape, acc_type);
+                                  outpad, stride, acc_type);
   }]>;
 
 // The tosa.matmul op is also intended to be generated where a fully_connected

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 0 additions & 1 deletion
@@ -408,7 +408,6 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
 
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
-    Tosa_IntArrayAttr4:$out_shape,
     TypeAttrOf<Tosa_AccType>:$acc_type,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 8 deletions
@@ -569,15 +569,15 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 
 /// Handles tosa.transpose_conv2d which has outpad and output shape
 /// attributes.
-static void buildTransConvOpWithQuantInfo(
-    OpBuilder &builder, OperationState &result, Type outputType, Value input,
-    Value weight, Value bias, DenseI64ArrayAttr outpad,
-    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
+static void
+buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+                              Type outputType, Value input, Value weight,
+                              Value bias, DenseI64ArrayAttr outpad,
+                              DenseI64ArrayAttr stride, TypeAttr accType) {
   auto zps = createZPsAsConst(builder, input, weight);
   result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("out_pad", outpad);
   result.addAttribute("stride", stride);
-  result.addAttribute("out_shape", outputShape);
   result.addAttribute("acc_type", accType);
   Type finalOutputType = outputType;
   auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
@@ -2327,9 +2327,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TransposeConv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  // outputShape is mutable.
-  llvm::SmallVector<int64_t> outputShape =
-      convertToMlirShape(adaptor.getOutShape());
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
 
   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
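With out_shape gone, TransposeConv2DOp::inferReturnTypeComponents starts from an all-dynamic rank-4 shape and refines each dimension from the input and weight shapes plus the stride and out_pad attributes. As a hedged illustration (shapes and the expected result are taken from the tosa-infer-shapes.mlir hunks below, assuming the usual --tosa-infer-shapes pass), an op written with a partially dynamic result type such as

  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>

is still inferred to produce tensor<2x33x45x5xf32>, matching the @transpose_conv2d_static_strided check.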

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 10 deletions
@@ -168,7 +168,7 @@ func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tens
 func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
   %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
   return %0 : tensor<1x32x32x16xi8>
 }
 
@@ -741,15 +741,6 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
 
 // -----
 
-// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
-func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
-  // expected-error@+1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with exactly 4 elements}}
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
-  return %0 : tensor<1x32x32x16xf32>
-}
-
-// -----
-
 // CHECK-LABEL: test_mul_type_mismatch
 func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
   %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>

mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

Lines changed: 9 additions & 13 deletions
@@ -32,19 +32,16 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
 
 // CHECK-LABEL: @transpose_conv2d_quantized_padded
 func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
-  // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
-  // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
-  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32}
-  // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
-  // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
-  // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
-  // CHECK-SAME: stride = array<i64: 1, 1>}
-  %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
-  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
+  // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[REV0]] {axis = 2 : i32}
+  // CHECK: tosa.conv2d %arg0, %[[REV1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>, stride = array<i64: 1, 1>}
+  %input_zp = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
     acc_type = i32,
     out_pad = array<i64: 1, 2, 3, 4>,
-    out_shape = array<i64: -1, -1, -1, -1>,
     stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32>
   return %0 : tensor<2x21x26x5xi32>
 }
@@ -160,12 +157,11 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
   // CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
   // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]]
   // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
-  %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8>
-  %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() <{value = dense<-103> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() <{value = dense<93> : tensor<1xi8>}> : () -> tensor<1xi8>
   %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
     acc_type = i32,
     out_pad = array<i64: 2, 0, 0, 1>,
-    out_shape = array<i64: 1, -1, -1, 1>,
     stride = array<i64: 1, 2>} :
     (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
   "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 8 additions & 8 deletions
@@ -907,7 +907,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_out_shape
 func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x8x9x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
   return
 }
 
@@ -916,7 +916,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_static
 func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x18x19x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -925,7 +925,7 @@ func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5
 // CHECK-LABEL: @transpose_conv2d_static_strided
 func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x33x45x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -934,7 +934,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1:
 // CHECK-LABEL: @transpose_conv2d_dynamic_input
 func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<?x?x?x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
   return
 }
 
@@ -943,7 +943,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten
 // CHECK-LABEL: @transpose_conv2d_dynamic_weights
 func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x?x?x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -952,7 +952,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t
 // CHECK-LABEL: @transpose_conv2d_dynamic_bias
 func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x8x9x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
   return
 }
 
@@ -961,14 +961,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens
 // CHECK-LABEL: @transpose_conv2d_padded
 func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x10x13x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
   return
 }
 
 // CHECK-LABEL: @transpose_conv2d_strided
 func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<1x13x13x1xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
   return
 }
 