Skip to content

Commit b923f6c

Browse files
Jerry-GeTai78641
andauthored
[mlir][tosa] Require PadOp's pad_const to be rank1 (#129156)
Update PadOp's pad_const input to be rank1. Fix various lit tests for this change including some conv ops Signed-off-by: Jerry Ge <[email protected]> Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Tai Ly <[email protected]>
1 parent c752924 commit b923f6c

File tree

8 files changed

+22
-21
lines changed

8 files changed

+22
-21
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1882,7 +1882,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
18821882
let arguments = (ins
18831883
Tosa_RankedTensor:$input1,
18841884
Tosa_Shape:$padding,
1885-
Optional<Tosa_Rank0Tensor>:$pad_const,
1885+
Optional<Tosa_ScalarTensor>:$pad_const,
18861886
OptionalAttr<I32Attr>:$input_zp
18871887
);
18881888

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
206206
}
207207

208208
auto denseAttr = DenseElementsAttr::get(
209-
RankedTensorType::get({}, elementTy), constantAttr);
209+
RankedTensorType::get({1}, elementTy), constantAttr);
210210
auto constantVal = rewriter.create<tosa::ConstOp>(
211211
op.getLoc(), denseAttr.getType(), denseAttr);
212212

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
127127

128128
Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
129129

130-
auto padTy = RankedTensorType::get({}, inputETy);
130+
auto padTy = RankedTensorType::get({1}, inputETy);
131131
auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
132132
Value padVal =
133133
rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,8 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
542542
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
543543
// CHECK: tensor.yield [[CST]]
544544
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
545-
%1 = arith.constant dense<42.0> : tensor<f32>
546-
%2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>) -> (tensor<4x9xf32>)
545+
%1 = arith.constant dense<42.0> : tensor<1xf32>
546+
%2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<4x9xf32>)
547547
return %2 : tensor<4x9xf32>
548548
}
549549

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
288288

289289
// CHECK-LABEL: @pad_determine_val_i32
290290
func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
291-
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
291+
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi32>}
292292
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
293293
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
294294
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
@@ -300,7 +300,7 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
300300

301301
// CHECK-LABEL: @pad_determine_val_f32
302302
func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
303-
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
303+
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}
304304
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
305305
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
306306
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
@@ -312,7 +312,7 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
312312

313313
// CHECK-LABEL: @pad_determine_val_quant
314314
func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
315-
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
315+
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi32>}
316316
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
317317
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
318318
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,10 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
218218

219219
// -----
220220

221-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
221+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
222222
%0 = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
223223
// expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
224-
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<i8>) -> tensor<13x21x3xi8>
224+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
225225
return %1 : tensor<13x21x3xi8>
226226
}
227227

@@ -248,7 +248,7 @@ func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
248248
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
249249
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
250250
%1 = "tosa.const"() {value = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
251-
// expected-error@+1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<2xf32>'}}
251+
// expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
252252
%2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
253253
return
254254
}
@@ -545,22 +545,22 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
545545
// -----
546546

547547
func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
548-
%input_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
549-
%weight_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
548+
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
549+
%weight_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
550550
// expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
551551
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
552-
: (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<i32>, tensor<i32>) -> tensor<1x27x27x16xf32>
552+
: (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x27x27x16xf32>
553553
return %0 : tensor<1x27x27x16xf32>
554554
}
555555

556556
// -----
557557

558558
func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
559-
%input_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
560-
%weight_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
559+
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
560+
%weight_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
561561
// expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}}
562562
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
563-
: (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<i32>, tensor<i32>) -> tensor<1x27x27x16xf32>
563+
: (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x27x27x16xf32>
564564
return %0 : tensor<1x27x27x16xf32>
565565
}
566566

@@ -1046,6 +1046,7 @@ func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> {
10461046
}
10471047

10481048
// -----
1049+
10491050
func.func @test_const_shape() -> !tosa.shape<4> {
10501051
// expected-error@+1 {{'tosa.const_shape' op attribute 'value' failed to satisfy constraint: index elements attribute}}
10511052
%cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4>

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,9 +591,9 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
591591
// -----
592592
// CHECK-LABEL: pad_explicit_value
593593
func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
594-
%0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
594+
%0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
595595
%padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
596-
%1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x3xf32>
596+
%1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
597597
return %1 : tensor<13x21x3xf32>
598598
}
599599

mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
6060
func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
6161
// CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>}
6262
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
63-
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
63+
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}
6464
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>}
6565
// CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 6]> : tensor<4xindex>}
6666
// CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>}
6767
// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
6868
// CHECK: %[[RESHAPE_I:.+]] = tosa.reshape %arg0, %[[CONST0]]
69-
// CHECK: %[[PAD_I:.+]] = tosa.pad %[[RESHAPE_I]], %[[PAD]], %[[ZERO]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
69+
// CHECK: %[[PAD_I:.+]] = tosa.pad %[[RESHAPE_I]], %[[PAD]], %[[ZERO]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<1xf32>) -> tensor<4x12x12x2x1xf32>
7070
// CHECK: %[[RESHAPE_ARG1:.+]] = tosa.reshape %arg1, %[[CONST3]]
7171
// CHECK: %[[MUL:.+]] = tosa.mul %[[PAD_I]], %[[RESHAPE_ARG1]], %[[SHIFT]]
7272
// CHECK: %[[RESHAPE_O:.+]] = tosa.reshape %[[MUL]], %[[CONST4]]

0 commit comments

Comments
 (0)