[mlir][tosa] Enhance verify checks for PAD Op #137177

Merged · 1 commit · Apr 28, 2025
50 changes: 43 additions & 7 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1534,15 +1534,51 @@ LogicalResult tosa::PadOp::verify() {
   if (!inputType || !outputType)
     return success();
 
-  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
+  auto inputRank = inputType.getRank();
+  auto outputRank = outputType.getRank();
+  if (inputRank != outputRank)
+    return emitOpError() << "expect same input and output tensor rank, but got "
+                         << "inputRank: " << inputRank
+                         << ", outputRank: " << outputRank;
 
-  if (inputType.getRank() != outputType.getRank())
-    return emitOpError() << "expect same input and output tensor rank.";
+  DenseIntElementsAttr paddingAttr;
+  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
+    return failure();
+  }
 
-  if (paddingRank != inputType.getRank() * 2)
-    return emitOpError() << "expected padding tensor dim 0 to have size "
-                         << inputType.getRank() * 2
-                         << " (2*rank(shape1)) but got size " << paddingRank;
+  auto paddingValues = paddingAttr.getValues<APInt>();
+  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
+    return emitOpError() << "padding tensor must have " << inputRank
+                         << " * 2 = " << inputRank * 2 << " elements, but got "
+                         << paddingValues.size();
+
+  auto inputShape = inputType.getShape();
+  auto outputShape = outputType.getShape();
+
+  for (int64_t i = 0; i < inputRank; ++i) {
+    int64_t padStart = paddingValues[i * 2].getSExtValue();
+    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
+
+    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
+      return emitOpError()
+             << "invalid padding values at dimension " << i
+             << ": values must be non-negative or -1 for dynamic padding, got ["
+             << padStart << ", " << padEnd << "]";
+    }
+
+    // Skip shape verification for dynamic input/output
+    if (inputShape[i] == ShapedType::kDynamic ||
+        outputShape[i] == ShapedType::kDynamic)
+      continue;
+
+    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
+      return emitOpError() << "mismatch in output shape at dimension " << i
+                           << ": expected " << inputShape[i] << " + "
+                           << padStart << " + " << padEnd << " = "
+                           << (inputShape[i] + padStart + padEnd)
+                           << ", but got " << outputShape[i];
+    }
+  }
 
   return success();
 }
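Net effect of the rewritten verifier: for every static dimension i it now requires outputShape[i] == inputShape[i] + padStart + padEnd, with -1 reserved as a dynamic-padding placeholder. A minimal sketch of IR that satisfies the new checks (the shapes mirror the updated tests below; the function name and the zero pad_const are illustrative only):

// Padding [0, 0, 0, 1, 0, 1] gives 13+0+0 = 13, 21+0+1 = 22, 3+0+1 = 4.
func.func @pad_shapes_ok(%arg0: tensor<13x21x3xi8>) -> tensor<13x22x4xi8> {
  %padding = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
  %pad_const = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
  %1 = tosa.pad %arg0, %padding, %pad_const : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
  return %1 : tensor<13x22x4xi8>
}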
24 changes: 23 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1112,9 +1112,31 @@ bool checkErrorIfRescale(Operation *op) {
   return true;
 }
 
+bool checkErrorIfPad(Operation *op) {
+  auto pad = dyn_cast<tosa::PadOp>(op);
+  if (!pad)
+    return true;
+
+  DenseIntElementsAttr paddingAttr;
+  if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
+    // Pad verifier will catch this
+    return true;
+
+  for (const APInt &val : paddingAttr.getValues<APInt>()) {
+    if (val.getSExtValue() < 0) {
+      op->emitOpError() << "padding value must all be non-negative, got "
+                        << val.getSExtValue();
+      return false;
+    }
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
-      !checkErrorIfTable(op) || !checkErrorIfRescale(op))
+      !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
+      !checkErrorIfPad(op))
     return failure();
   return success();
 }
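The two hunks divide the work: the PadOp verifier above tolerates -1 as a dynamic-padding placeholder, while this ERROR_IF validation rejects any negative padding outright. A sketch of IR that passes the verifier (10 + (-1) + 1 == 10) but is flagged by checkErrorIfPad once the validation pass runs (function name illustrative; the values mirror the invalid.mlir test below):

func.func @pad_dynamic_start(%arg0: tensor<10xf32>) -> tensor<10xf32> {
  %padding = tosa.const_shape {values = dense<[-1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
  // Validation pass emits: padding value must all be non-negative, got -1
  %1 = tosa.pad %arg0, %padding, %pad_const : (tensor<10xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<10xf32>
  return %1 : tensor<10xf32>
}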
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -20,10 +20,10 @@ func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>)
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %1 : tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+  return %1 : tensor<13x22x4xi8>
 }
 
 // -----
34 changes: 12 additions & 22 deletions mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,7 +272,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
+func.func @test_pad_padding_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
   %0 = tosa.pad %arg0, %arg1, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -281,19 +281,19 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_const_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %1 : tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+  return %1 : tensor<13x22x4xi8>
 }
 
 // -----
 
 func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
   %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
+  // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank}}
   %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
 }
 
@@ -307,17 +307,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
 
 // -----
 
-func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
-  %0 = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
-  %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
-  %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21xf32>
-  return
-}
-
-// -----
-
-func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
   %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
   // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
@@ -327,12 +327,12 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
 
 // -----
 
-func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
-  %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
-  return %1 : tensor<13x21x3xf32>
+func.func @test_pad_invalid_padding_value(%arg0: tensor<10xf32>) {
+  %0 = tosa.const_shape {values = dense<[-1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  // expected-error@+1 {{padding value must all be non-negative, got -1}}
+  %2 = tosa.pad %arg0, %0, %1 : (tensor<10xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<10xf32>
+  return
 }
 
 // -----
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -407,11 +407,11 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %1 : tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+  return %1 : tensor<13x22x4xi8>
 }
 
 // -----
34 changes: 34 additions & 0 deletions mlir/test/Dialect/Tosa/verifier.mlir
@@ -403,3 +403,37 @@ func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1
   %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
   return %0 : tensor<13x26x8xf32>
 }
+
+// -----
+func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.pad' op padding tensor must have 3 * 2 = 6 elements, but got 4}}
+  %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
+  return %1 : tensor<13x21x3xf32>
+}
+
+// -----
+func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
+  %0 = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.pad' op padding tensor must have 2 * 2 = 4 elements, but got 6}}
+  %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21xf32>
+  return
+}
+
+// -----
+func.func @test_pad_output_mismatch(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{mismatch in output shape at dimension 1: expected 21 + 0 + 1 = 22, but got 21}}
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1xi8>) -> tensor<10xi8> {
+  %0 = tosa.const_shape {values = dense<[-2, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{invalid padding values at dimension 0: values must be non-negative or -1 for dynamic padding, got [-2, 2]}}
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<10xi8>, !tosa.shape<2>, tensor<1xi8>) -> tensor<10xi8>
+  return %1 : tensor<10xi8>
+}