[mlir][arith] Fix multiplication canonicalizations #144787

Closed
wants to merge 1 commit into from

Conversation

gysit
Contributor

@gysit gysit commented Jun 18, 2025

The Arith dialect includes patterns that canonicalize a sequence of:

  • trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
  • trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication, which assumes that the shift amount is equal to the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount was less than the bit width.

For example, the following code:

  %x = arith.extui %a: i32 to i33
  %y = arith.extui %b: i32 to i33
  %m = arith.muli %x, %y: i33
  %c1 = arith.constant 1: i33
  %sh = arith.shrui %m, %c1 : i33
  %hi = arith.trunci %sh: i33 to i32

would incorrectly be canonicalized to:

_, %hi = arith.mului_extended %a, %b : i32
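To make the mismatch concrete, here is a minimal sketch in plain Python that models the two computations on unsigned 32-bit inputs. The helper names (`shifted_i33_pattern`, `mului_extended_high`) are hypothetical and only mirror the semantics described above; they are not part of MLIR.

```python
MASK32 = (1 << 32) - 1
MASK33 = (1 << 33) - 1


def shifted_i33_pattern(a: int, b: int, shift: int = 1) -> int:
    """Model extui -> muli (i33) -> shrui -> trunci from the example."""
    x = a & MASK32          # arith.extui: zero-extend i32 to i33
    y = b & MASK32
    m = (x * y) & MASK33    # arith.muli wraps modulo 2**33
    sh = m >> shift         # arith.shrui by the constant
    return sh & MASK32      # arith.trunci back to i32


def mului_extended_high(a: int, b: int) -> int:
    """High 32 bits of the full 64-bit unsigned product."""
    return ((a & MASK32) * (b & MASK32)) >> 32


a, b = 3, 5
before = shifted_i33_pattern(a, b)   # 15 >> 1
after = mului_extended_high(a, b)    # 15 >> 32
print(before, after)
```

For these inputs the original sequence yields 7 while the high word of the extended product is 0, so the rewrite changes the result.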

@llvmbot
Member

llvmbot commented Jun 18, 2025

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Tobias Gysi (gysit)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144787.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+11-3)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+31-1)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 13eb97a910bd4..2f7beed549108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
     Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
         (SelectOp $pred, $a, $c)>;
 
-// select(pred, false, true) => not(pred) 
+// select(pred, false, true) => not(pred)
 def SelectI1ToNot :
     Pat<(SelectOp $pred,
                   (ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -376,6 +376,12 @@ def TruncationMatchesShiftAmount :
       CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
               "*getIntOrSplatIntValue($2)">]>>;
 
+def ValueWidthMatchesShiftAmount :
+    Constraint<And<[
+      CPred<"succeeded(getIntOrSplatIntValue($1))">,
+      CPred<"getScalarOrElementWidth($0) == "
+              "*getIntOrSplatIntValue($1)">]>>;
+
 // trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
 def TruncIExtSIToExtSI :
     Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
@@ -406,7 +412,8 @@ def TruncIShrUIMulIToMulSIExtended :
         (Arith_MulSIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
        (ValueWiderThan $mul, $x),
-       (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+       (TruncationMatchesShiftAmount $mul, $x, $c0),
+       (ValueWidthMatchesShiftAmount $x, $c0)]>;
 
 // trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
 def TruncIShrUIMulIToMulUIExtended :
@@ -417,7 +424,8 @@ def TruncIShrUIMulIToMulUIExtended :
         (Arith_MulUIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
        (ValueWiderThan $mul, $x),
-       (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+       (TruncationMatchesShiftAmount $mul, $x, $c0),
+       (ValueWidthMatchesShiftAmount $x, $c0)]>;
 
 //===----------------------------------------------------------------------===//
 // TruncIOp
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b6188c81ff912..542603722ab8a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
 
 
 // CHECK-LABEL: @foldSubXX_tensor
-//       CHECK:   %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> 
+//       CHECK:   %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
 //       CHECK:   %[[sub:.+]] = arith.subi
 //       CHECK:   return %[[c0]], %[[sub]]
 func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2966,6 +2966,21 @@ func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
   return %hi : i32
 }
 
+// Verify that the signed extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift
+//   CHECK-NOT: arith.mulsi_extended
+func.func @wideMulToMulSIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+  %x = arith.extsi %a: i32 to i33
+  %y = arith.extsi %b: i32 to i33
+  %m = arith.muli %x, %y: i33
+  %c1 = arith.constant 1: i33
+  %sh = arith.shrui %m, %c1 : i33
+  %hi = arith.trunci %sh: i33 to i32
+  return %hi : i32
+}
+
 // CHECK-LABEL: @wideMulToMulSIExtendedVector
 //  CHECK-SAME:   (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
 //  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
@@ -2994,6 +3009,21 @@ func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
   return %hi : i32
 }
 
+// Verify that the unsigned extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulUIExtendedWithWrongShift
+//   CHECK-NOT: arith.mului_extended
+func.func @wideMulToMulUIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+  %x = arith.extui %a: i32 to i33
+  %y = arith.extui %b: i32 to i33
+  %m = arith.muli %x, %y: i33
+  %c1 = arith.constant 1: i33
+  %sh = arith.shrui %m, %c1 : i33
+  %hi = arith.trunci %sh: i33 to i32
+  return %hi : i32
+}
+
 // CHECK-LABEL: @wideMulToMulUIExtendedVector
 //  CHECK-SAME:   (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
 //  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>

@gysit gysit requested a review from kuhar June 18, 2025 19:58
@gysit
Contributor Author

gysit commented Jun 18, 2025

@kuhar this showed up downstream. I tried to fix it to the best of my understanding, but it definitely needs a second pair of eyes.

@gysit gysit requested a review from Copilot June 18, 2025 20:01
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR ensures that the Arith dialect’s extended multiplication canonicalization patterns only fire when the right‐shift amount exactly equals the operand bitwidth.

  • Added negative MLIR tests to verify that neither signed nor unsigned extended multiply patterns match when the shift constant is incorrect.
  • Introduced a new ValueWidthMatchesShiftAmount constraint and applied it to both signed (mulsi_extended) and unsigned (mului_extended) patterns.
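As a rough illustration of how the two width constraints interact, the following Python sketch restates the `CPred` expressions from the diff as plain functions. The function names and the `pattern_may_fire` helper are hypothetical, a shift of `None` stands in for a non-constant shift amount, and the other constraints on the patterns (`ValuesWithSameType`, `ValueWiderThan`) are deliberately left out.

```python
from typing import Optional


def truncation_matches_shift_amount(mul_width: int, x_width: int,
                                    shift: Optional[int]) -> bool:
    # Existing check: width(mul) - width(x) == shift amount.
    return shift is not None and (mul_width - x_width) == shift


def value_width_matches_shift_amount(x_width: int,
                                     shift: Optional[int]) -> bool:
    # New check from this PR: width(x) == shift amount.
    return shift is not None and x_width == shift


def pattern_may_fire(mul_width: int, x_width: int,
                     shift: Optional[int]) -> bool:
    # Models only the two width-related constraints shown in the diff.
    return (truncation_matches_shift_amount(mul_width, x_width, shift)
            and value_width_matches_shift_amount(x_width, shift))


# i33 multiplication of i32 operands, shifted by 1: passes the old check only.
print(truncation_matches_shift_amount(33, 32, 1), pattern_may_fire(33, 32, 1))
# i64 multiplication of i32 operands, shifted by 32: passes both checks.
print(pattern_may_fire(64, 32, 32))
```

Under this model, the two checks together hold only when the multiplication is exactly twice as wide as the operands and the shift equals the operand width.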

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| mlir/test/Dialect/Arith/canonicalize.mlir | Added CHECK-NOT tests for wrong shift amounts in both signed and unsigned cases |
| mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | Defined ValueWidthMatchesShiftAmount constraint and added it to relevant patterns |
Comments suppressed due to low confidence (3)

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td:379

  • [nitpick] Add a brief comment above this constraint explaining that it requires the shift constant to equal the operand's bitwidth; this will help future maintainers understand its intent.
def ValueWidthMatchesShiftAmount :

mlir/test/Dialect/Arith/canonicalize.mlir:2972

  • [nitpick] Consider adding a positive test case for the scenario where the shift constant equals the operand bitwidth to verify that the canonicalization still fires as expected.
// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td:379

  • [nitpick] The name ValueWidthMatchesShiftAmount is quite long; consider renaming it to something like ShiftEqualsValueWidth for improved readability.
def ValueWidthMatchesShiftAmount :

Member

@kuhar kuhar left a comment

Thanks for fixing this. I think we should drop these canon patterns altogether -- the issue is that one of the results is always unused so it's probably too aggressive to be enabled by default without a clear benefit. Similarly, it doesn't allow us to decompose muli_extended to plain muli without being folded back again.

I'm not aware of any code that relies on these patterns anymore.
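For readers unfamiliar with the decomposition mentioned here, the sketch below models, in plain Python, how the high word of a signed extended multiplication can be computed with an ordinary multiplication at twice the operand width. It assumes the shift amount equals the operand bit width; the helper names are hypothetical and are not MLIR APIs.

```python
def to_signed(value: int, bits: int) -> int:
    """Interpret the low `bits` bits of `value` as a two's-complement integer."""
    value &= (1 << bits) - 1
    return value - (1 << bits) if value >> (bits - 1) else value


def high_word_via_plain_mul(a: int, b: int, bits: int = 32) -> int:
    """Model extsi -> muli (2*bits wide) -> shrui by `bits` -> trunci."""
    wide_mask = (1 << (2 * bits)) - 1
    x = to_signed(a, bits)           # arith.extsi
    y = to_signed(b, bits)
    m = (x * y) & wide_mask          # arith.muli wraps at 2*bits
    sh = m >> bits                   # arith.shrui (logical shift)
    return sh & ((1 << bits) - 1)    # arith.trunci


def mulsi_extended_high(a: int, b: int, bits: int = 32) -> int:
    """High word of the exact signed product, as a `bits`-wide bit pattern."""
    product = to_signed(a, bits) * to_signed(b, bits)
    return (product >> bits) & ((1 << bits) - 1)


samples = [0, 1, 2, 0x7FFFFFFF, 0x80000000, 0xFFFFFFFF]
agree = all(high_word_via_plain_mul(a, b) == mulsi_extended_high(a, b)
            for a in samples for b in samples)
print(agree)
```

Because the two forms compute the same value when the shift matches the operand width, a canonicalization in one direction will undo a decomposition in the other, which is the interaction raised in this comment.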

gysit added a commit to gysit/llvm-project that referenced this pull request Jun 19, 2025
The Arith dialect includes patterns that canonicalize a sequence of:

- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication, which
assumes that the shift amount is equal to the bit width of the original
operands. This check was missing, leading to incorrect canonicalizations
when the shift amount was less than the bit width.

For example, the following code:
```
  %x = arith.extui %a: i32 to i33
  %y = arith.extui %b: i32 to i33
  %m = arith.muli %x, %y: i33
  %c1 = arith.constant 1: i33
  %sh = arith.shrui %m, %c1 : i33
  %hi = arith.trunci %sh: i33 to i32
```
would incorrectly be canonicalized to:
```
_, %hi = arith.mului_extended %a, %b : i32
```
This commit removes the faulty canonicalizations since they are not
believed to be generally beneficial (cf. the discussion of the
alternative llvm#144787 which fixes
the canonicalizations).
@gysit
Contributor Author

gysit commented Jun 19, 2025

Ok, closing this one in favor of #144844.

@gysit gysit closed this Jun 19, 2025
gysit added a commit to gysit/llvm-project that referenced this pull request Jun 19, 2025
gysit added a commit that referenced this pull request Jun 19, 2025
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jun 19, 2025
gysit added a commit that referenced this pull request Jun 20, 2025
@joker-eph
Collaborator

> Thanks for fixing this. I think we should drop these canon patterns altogether -- the issue is that one of the results is always unused so it's probably too aggressive to be enabled by default without a clear benefit.

Why is it a problem? Seems like the cost-model thing is a lowering problem.
At the arithmetic level, having fewer ops is generally beneficial. If mulsi_extended is considered differently, then it should probably move out of the arith dialect.

> Similarly, it doesn't allow us to decompose muli_extended to plain muli without being folded back again.

That is completely expected: this kind of decomposition is "lowering" (or "codegen prepare") and as such isn't expected to be resilient to canonicalization.

@gysit
Contributor Author

gysit commented Jun 20, 2025

I do not have a strong opinion on keeping vs removing the canonicalization (as long as the bug is fixed). It is true that the output of the patterns seems more canonical from an arith dialect perspective.

@kuhar
Member

kuhar commented Jun 20, 2025

> Why is it a problem? Seems like the cost-model thing is a lowering problem.

This pattern is kind of obscure while adding logical complexity / maintenance cost to the codebase. I remember it appearing exactly once in an internal model and I'm not aware of any other real-world input where it actually matches. I think this kind of fold is best done by each backend when we know whether the target natively calculates the high result of multiplication or not.

> That is completely expected: this kind of decomposition is "lowering" (or "codegen prepare") and as such isn't expected to be resilient to canonicalization.

Ack, I agree with you on this point after thinking about this more.

But overall, I don't think this pattern is worth keeping. If someone actually finds a use case for it, they should be able to re-add it without much effort.
