Skip to content

Commit 2078c7c

Browse files
committed
[mlir][tosa] Enhance verify checks for PAD Op
* add padding shape verification * add and update LIT test Change-Id: Ie77ba21d271362906618389cf90cf0af20e2fcae Signed-off-by: Peng Sun <[email protected]>
1 parent 1041d54 commit 2078c7c

File tree

4 files changed

+61
-18
lines changed

4 files changed

+61
-18
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,15 +1515,49 @@ LogicalResult tosa::PadOp::verify() {
15151515
if (!inputType || !outputType)
15161516
return success();
15171517

1518-
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
1518+
auto inputRank = inputType.getRank();
1519+
auto outputRank = outputType.getRank();
1520+
if (inputRank != outputRank)
1521+
return emitOpError() << "expect same input and output tensor rank, but got "
1522+
<< "inputRank: " << inputRank
1523+
<< ", outputRank: " << outputRank;
1524+
1525+
DenseIntElementsAttr paddingAttr;
1526+
if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
1527+
return failure();
1528+
1529+
auto paddingValues = paddingAttr.getValues<APInt>();
1530+
if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1531+
return emitOpError() << "padding tensor must have " << inputRank
1532+
<< " * 2 = " << inputRank * 2 << " elements, but got "
1533+
<< paddingValues.size();
1534+
1535+
auto inputShape = inputType.getShape();
1536+
auto outputShape = outputType.getShape();
1537+
1538+
for (int64_t i = 0; i < inputRank; ++i) {
1539+
// Skip shape verification for dynamic dims
1540+
if (inputShape[i] == ShapedType::kDynamic ||
1541+
outputShape[i] == ShapedType::kDynamic)
1542+
continue;
1543+
1544+
int64_t padStart = paddingValues[i * 2].getSExtValue();
1545+
int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
15191546

1520-
if (inputType.getRank() != outputType.getRank())
1521-
return emitOpError() << "expect same input and output tensor rank.";
1547+
if (padStart < 0 || padEnd < 0) {
1548+
return emitOpError() << "padding values must be non-negative, got ["
1549+
<< padStart << ", " << padEnd << "] for dimension "
1550+
<< i;
1551+
}
15221552

1523-
if (paddingRank != inputType.getRank() * 2)
1524-
return emitOpError() << "expected padding tensor dim 0 to have size "
1525-
<< inputType.getRank() * 2
1526-
<< " (2*rank(shape1)) but got size " << paddingRank;
1553+
if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1554+
return emitOpError() << "mismatch in output shape at dimension " << i
1555+
<< ": expected " << inputShape[i] << " + "
1556+
<< padStart << " + " << padEnd << " = "
1557+
<< (inputShape[i] + padStart + padEnd)
1558+
<< ", but got " << outputShape[i];
1559+
}
1560+
}
15271561

15281562
return success();
15291563
}

mlir/test/Dialect/Tosa/dynamic_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>)
2020

2121
// -----
2222

23-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
23+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
2424
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
25-
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
26-
return %1 : tensor<13x21x3xi8>
25+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
26+
return %1 : tensor<13x22x4xi8>
2727
}
2828

2929
// -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2
303303

304304
// -----
305305

306-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
306+
func.func @test_pad_padding_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
307307
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
308308
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
309309
%0 = tosa.pad %arg0, %arg1, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -312,9 +312,18 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
312312

313313
// -----
314314

315-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
315+
func.func @test_pad_pad_const_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
316316
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
317317
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
318+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
319+
return %1 : tensor<13x22x4xi8>
320+
}
321+
322+
// -----
323+
324+
func.func @test_pad_output_mismatch(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
325+
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
326+
// expected-error@+1 {{mismatch in output shape at dimension 1: expected 21 + 0 + 1 = 22, but got 21}}
318327
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
319328
return %1 : tensor<13x21x3xi8>
320329
}
@@ -324,7 +333,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) ->
324333
func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
325334
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
326335
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
327-
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
336+
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank}}
328337
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
329338
}
330339

@@ -341,7 +350,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
341350
func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
342351
%0 = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
343352
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
344-
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
353+
// expected-error@+1 {{'tosa.pad' op padding tensor must have 2 * 2 = 4 elements, but got 6}}
345354
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21xf32>
346355
return
347356
}
@@ -361,7 +370,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
361370
func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
362371
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
363372
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
364-
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
373+
// expected-error@+1 {{'tosa.pad' op padding tensor must have 3 * 2 = 6 elements, but got 4}}
365374
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
366375
return %1 : tensor<13x21x3xf32>
367376
}

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,11 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
407407

408408
// -----
409409

410-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
410+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
411411
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
412412
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
413-
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
414-
return %1 : tensor<13x21x3xi8>
413+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
414+
return %1 : tensor<13x22x4xi8>
415415
}
416416

417417
// -----

0 commit comments

Comments
 (0)