-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Fix for incorrect cannonicalization of tosa.pad #98356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The current fold method for tosa.pad can produce invalid IR by replacing the padded value with the tosa.pad is a noop. When the type of the input value does not match the type of the tosa.pad, the canonicalizer detects the change in types and asserts.
@llvm/pr-subscribers-mlir-tosa Author: Spenser Bauman (sabauma) ChangesThe current fold method for tosa.pad can produce invalid IR by replacing the padded value with the tosa.pad is a noop. When the type of the input value does not match the type of the tosa.pad, the canonicalizer detects the change in types and asserts. This change addresses the issue by avoiding folding when the input and result types do not match. Full diff: https://github.com/llvm/llvm-project/pull/98356.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8687be075ea67..866ab0d2228f7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -859,7 +859,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
- if (adaptor.getPadding()) {
+ if (adaptor.getPadding() && getInput1().getType() == getType()) {
auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index accc792c8f2ac..3bcf58015831b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -217,6 +217,20 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// -----
+// CHECK-LABEL: @pad_noop_type_mismatch_nofold
+func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> {
+ // CHECK: %[[PAD:.+]] = tosa.pad
+ // CHECK: return %[[PAD]]
+
+ %c0_i32 = arith.constant 0 : i32
+ %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+
+ %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: @pad_determine_val_i32
func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
|
@llvm/pr-subscribers-mlir Author: Spenser Bauman (sabauma) ChangesThe current fold method for tosa.pad can produce invalid IR by replacing the padded value with the tosa.pad is a noop. When the type of the input value does not match the type of the tosa.pad, the canonicalizer detects the change in types and asserts. This change addresses the issue by avoiding folding when the input and result types do not match. Full diff: https://github.com/llvm/llvm-project/pull/98356.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8687be075ea67..866ab0d2228f7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -859,7 +859,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
- if (adaptor.getPadding()) {
+ if (adaptor.getPadding() && getInput1().getType() == getType()) {
auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index accc792c8f2ac..3bcf58015831b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -217,6 +217,20 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// -----
+// CHECK-LABEL: @pad_noop_type_mismatch_nofold
+func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> {
+ // CHECK: %[[PAD:.+]] = tosa.pad
+ // CHECK: return %[[PAD]]
+
+ %c0_i32 = arith.constant 0 : i32
+ %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+
+ %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: @pad_determine_val_i32
func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
|
) The current fold method for tosa.pad can produce invalid IR by replacing the padded value with the tosa.pad is a noop. When the type of the input value does not match the type of the tosa.pad, the canonicalizer detects the change in types and asserts. This change addresses the issue by avoiding folding when the input and result types do not match.
The current fold method for tosa.pad can produce invalid IR by replacing the padded value with the tosa.pad is a noop. When the type of the input value does not match the type of the tosa.pad, the canonicalizer detects the change in types and asserts.
This change addresses the issue by avoiding folding when the input and result types do not match.