Skip to content

Commit 2fb53f3

Browse files
authored
[mlir][tosa] Fix for incorrect cannonicalization of tosa.pad (#98356)
The current fold method for tosa.pad can produce invalid IR by replacing the padded value with the tosa.pad is a noop. When the type of the input value does not match the type of the tosa.pad, the canonicalizer detects the change in types and asserts. This change addresses the issue by avoiding folding when the input and result types do not match.
1 parent b64c1de commit 2fb53f3

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
859859

860860
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
861861
// If the pad is all zeros we can fold this operation away.
862-
if (adaptor.getPadding()) {
862+
if (adaptor.getPadding() && getInput1().getType() == getType()) {
863863
auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
864864
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
865865
return getInput1();

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,20 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
217217

218218
// -----
219219

220+
// CHECK-LABEL: @pad_noop_type_mismatch_nofold
221+
func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> {
222+
// CHECK: %[[PAD:.+]] = tosa.pad
223+
// CHECK: return %[[PAD]]
224+
225+
%c0_i32 = arith.constant 0 : i32
226+
%shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
227+
228+
%0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
229+
return %0 : tensor<?xf32>
230+
}
231+
232+
// -----
233+
220234
// CHECK-LABEL: @pad_determine_val_i32
221235
func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
222236
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}

0 commit comments

Comments
 (0)