Skip to content

[mlir][arith] Add constraints to the MulIOp for preventing type mismatch while folding #136093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2025

Conversation

Prakhar-Dixit
Copy link
Contributor

Fixes #135289
The original version didn't check if the types of lhs, rhs, and the result matched, which could cause type errors.
This fix adds type checks to make sure the constants have the same type as the result before applying the simplification.

Minimal example crashing :

func.func @nested_muli() -> (i32) {
  %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
  %1 = "test.constant"() {value = -2147483648} : () -> i32
  %2 = "test.constant"() {value = 0x80000000} : () -> i32
  %4 = arith.muli %0, %1 : i32
  %5 = arith.muli %4, %2 : i32
  return %5 : i32
}

@llvmbot
Copy link
Member

llvmbot commented Apr 17, 2025

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Prakhar Dixit (Prakhar-Dixit)

Changes

Fixes #135289
The original version didn't check if the types of lhs, rhs, and the result matched, which could cause type errors.
This fix adds type checks to make sure the constants have the same type as the result before applying the simplification.

Minimal example crashing :

func.func @<!-- -->nested_muli() -&gt; (i32) {
  %0 = "test.constant"() {value = 0x7fffffff} : () -&gt; i32
  %1 = "test.constant"() {value = -2147483648} : () -&gt; i32
  %2 = "test.constant"() {value = 0x80000000} : () -&gt; i32
  %4 = arith.muli %0, %1 : i32
  %5 = arith.muli %4, %2 : i32
  return %5 : i32
}

Full diff: https://github.com/llvm/llvm-project/pull/136093.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+3-1)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+18)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 6d7ac2be951dd..7e212df9029d1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -90,7 +90,9 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+             [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+              (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f0b2731707d18..d62c5b18fd041 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1234,6 +1234,24 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// Negative test case to ensure no further folding is performed when there's a type mismatch between the values and the result.
+// CHECK-LABEL:   func.func @nested_muli() -> i32 {
+// CHECK:           %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK:           %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK:           %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK:           return %[[VAL_4]] : i32
+// CHECK:         }
+func.func @nested_muli() -> (i32) {
+  %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+  %1 = "test.constant"() {value = -2147483648} : () -> i32
+  %2 = "test.constant"() {value = 0x80000000} : () -> i32
+  %4 = arith.muli %0, %1 : i32
+  %5 = arith.muli %4, %2 : i32
+  return %5 : i32
+}
+
 // CHECK-LABEL: @tripleMulIMulIIndex
 //       CHECK:   %[[cres:.+]] = arith.constant 15 : index
 //       CHECK:   %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index

@joker-eph joker-eph merged commit 35f4cdb into llvm:main Apr 17, 2025
14 checks passed
@Prakhar-Dixit Prakhar-Dixit deleted the pd/135289 branch April 17, 2025 09:19
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…tch while folding (llvm#136093)

Fixes llvm#135289
The original version didn't check if the types of lhs, rhs, and the
result matched, which could cause type errors.
This fix adds type checks to make sure the constants attributes have
the same type as the SSA values before applying the simplification.
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get this PR -- why do we only fix this for mul folders but not other ones? It seems like this might have been merged prematurely.

Isn't the root cause that it's generally unsupported to have contants whose attribute types don't match ssa types?

@joker-eph
Copy link
Collaborator

joker-eph commented Apr 17, 2025

We should fix them all in all likeliness. I tried to prevent constant where the attribute type mismatches the SSA value in the past, but it's a rabbit hole :(

I wrote a verifier for it at some point, but many cases are problematic:

complex.constant doesn't return a TypedAttribute:

  %0 = "complex.constant"() <{value = [1.200000e+00 : f32, 2.300000e+00 : f32]}> : () -> complex<f32>

tosa.const returns a TypedAttribute and mismatches:

  %input = "tosa.const"() {values = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>

llvm.constant accepts index for the attribute value:

  %0 = llvm.mlir.constant(1 : index) : i64

@kuhar
Copy link
Member

kuhar commented Apr 17, 2025

Could we handle this in some underlying utility used in DDR like m_Constant?

@kuhar
Copy link
Member

kuhar commented Apr 17, 2025

I would like to understand why we proceeded with this isolated workaround without explaining what the underlying issue is and considering proper solutions (or referencing previous investigations). This seems very far from a principled fix and I'm concerned about the proliferation of similar workarounds in the future.

@joker-eph
Copy link
Collaborator

I would like to understand why we proceeded with this isolated workaround

Because the underlying issue seems hard to fix fundamentally?

without explaining what the underlying issue is

What kind of explanation are you looking for other than "some constant returns an attribute that does not match the SSA value type"?

referencing previous investigations

There is a bug where we tried to look at llvm.mlir.constant behavior, as mentioned above, I can't find it (I must be bad at GitHub search, or the tool isn't very good...).

@joker-eph
Copy link
Collaborator

One of such issue: #74236
And the workaround PR: https://github.com/llvm/llvm-project/pull/88314/files

@kuhar
Copy link
Member

kuhar commented Apr 17, 2025

without explaining what the underlying issue is

What kind of explanation are you looking for other than "some constant returns an attribute that does not match the SSA value type"?

Thanks for linking the historical issues, I was not aware of this being identified an issue in the past. I would be great to put similar urls in the PR description to add context and be able to follow previous discussions and fixes.

For example, this piece of code provides a nice explanation for doing things a certain way: https://github.com/llvm/llvm-project/pull/88314/files#diff-c62b57552386a2a552ce6e3fe37bc23d399f2508636851b77ec6cd47fee906af.

Because the underlying issue seems hard to fix fundamentally?

The issue with this PR is that it seems to work around a larger bug in a way that applies to a single instance rather than following standard compiler engineering practices and attempting to provide a general solution. For example, I see many other patterns in this file that I'd also expect to run into the same problem -- I think it applies to the other folds that use MergeOverflow (?). I'd expect either the other uses of this to be updated as well or a brief explanation of whether this maybe an isolated issue with this pattern if this is not the case. I don't propose to fix the universe at once, but at least see some consideration for whether we can handle the issue in the same vicinity.

@makslevental
Copy link
Contributor

makslevental commented Apr 17, 2025

Because the underlying issue seems hard to fix fundamentally?

complex.constant doesn't return a TypedAttribute

Just double checking my understanding: is the issue that there's no uniform/generic interface for getting the type of the OpFoldResult? If so, couldn't we just promote the ConstantLike trait to full-fledged interface with a required member called getOpFoldResultType()? Then e.g., this PR could be spelled

(Constraint<CPred<"$0.getType() == cast<ConstantLike>($1).getOpFoldResultType()">> $res, $c0),
(Constraint<CPred<"$0.getType() == cast<ConstantLike>($1).getOpFoldResultType()">> $res, $c1)

or (preferrably) we could just bake this directly into the canonicalization/folding system.

@joker-eph
Copy link
Collaborator

If so, couldn't we just promote the ConstantLike trait to full-fledged interface with a required member called getOpFoldResultType()?

I'm not sure what that would get us? The folder still return an attribute and the consumer will try to handle it.
The problem is when this attribute type does not match the SSA value type.

Then e.g., this PR could be spelled

I think @kuhar 's objection is on the fast that the constraints needs to exist on the pattern in the first place.

or (preferrably) we could just bake this directly into the canonicalization/folding system.

Folding already forbids to fold to a different type, however we don't call the folder on constants themselves so it never shows up there (IIUC).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Assertion `succeeded( ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...))' failed
5 participants