[mlir][arith] Fix multiplication canonicalizations #144787
The Arith dialect includes patterns that canonicalize a sequence of:

- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication, which assumes that the shift amount is equal to the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount was less than the bit width.

For example, the following code:

```mlir
%x = arith.extui %a: i32 to i33
%y = arith.extui %b: i32 to i33
%m = arith.muli %x, %y: i33
%c1 = arith.constant 1: i33
%sh = arith.shrui %m, %c1 : i33
%hi = arith.trunci %sh: i33 to i32
```

would incorrectly be canonicalized to:

```mlir
_, %hi = arith.mului_extended %a, %b : i32
```
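To see why the rewrite is wrong, the two sequences can be evaluated on concrete values. The following is a minimal sketch in Python that models fixed-width integer arithmetic with masks; the helper names (`mask`, `extui_mul_shr_trunc`, `mului_extended_high`) are hypothetical and only mirror the semantics described above, not any MLIR API.

```python
def mask(width):
    """Bit mask with the low `width` bits set."""
    return (1 << width) - 1


def extui_mul_shr_trunc(a, b, operand_width=32, ext_width=33, shift=1):
    """Model extui -> muli -> shrui -> trunci on fixed-width integers."""
    x = a & mask(operand_width)          # extui keeps the unsigned value
    y = b & mask(operand_width)
    m = (x * y) & mask(ext_width)        # muli wraps at the extended width
    sh = m >> shift                      # shrui by the constant shift amount
    return sh & mask(operand_width)      # trunci back to the operand width


def mului_extended_high(a, b, width=32):
    """Model the high result of an unsigned extended multiplication."""
    product = (a & mask(width)) * (b & mask(width))
    return (product >> width) & mask(width)


original = extui_mul_shr_trunc(2, 3)     # (2 * 3) >> 1 == 3
rewritten = mului_extended_high(2, 3)    # high word of 6 is 0
print(original, rewritten)
```

With `a = 2` and `b = 3` the original sequence yields 3 while the high word of the product is 0, so the two forms are not equivalent when the shift amount is 1.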
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir
Author: Tobias Gysi (gysit)
Full diff: https://github.com/llvm/llvm-project/pull/144787.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 13eb97a910bd4..2f7beed549108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
(SelectOp $pred, $a, $c)>;
-// select(pred, false, true) => not(pred)
+// select(pred, false, true) => not(pred)
def SelectI1ToNot :
Pat<(SelectOp $pred,
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -376,6 +376,12 @@ def TruncationMatchesShiftAmount :
CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
"*getIntOrSplatIntValue($2)">]>>;
+def ValueWidthMatchesShiftAmount :
+ Constraint<And<[
+ CPred<"succeeded(getIntOrSplatIntValue($1))">,
+ CPred<"getScalarOrElementWidth($0) == "
+ "*getIntOrSplatIntValue($1)">]>>;
+
// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
def TruncIExtSIToExtSI :
Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
@@ -406,7 +412,8 @@ def TruncIShrUIMulIToMulSIExtended :
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
- (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+ (TruncationMatchesShiftAmount $mul, $x, $c0),
+ (ValueWidthMatchesShiftAmount $x, $c0)]>;
// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
def TruncIShrUIMulIToMulUIExtended :
@@ -417,7 +424,8 @@ def TruncIShrUIMulIToMulUIExtended :
(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
- (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+ (TruncationMatchesShiftAmount $mul, $x, $c0),
+ (ValueWidthMatchesShiftAmount $x, $c0)]>;
//===----------------------------------------------------------------------===//
// TruncIOp
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b6188c81ff912..542603722ab8a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
// CHECK-LABEL: @foldSubXX_tensor
-// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
// CHECK: %[[sub:.+]] = arith.subi
// CHECK: return %[[c0]], %[[sub]]
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2966,6 +2966,21 @@ func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
return %hi : i32
}
+// Verify that the signed extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift
+// CHECK-NOT: arith.mulsi_extended
+func.func @wideMulToMulSIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+ %x = arith.extsi %a: i32 to i33
+ %y = arith.extsi %b: i32 to i33
+ %m = arith.muli %x, %y: i33
+ %c1 = arith.constant 1: i33
+ %sh = arith.shrui %m, %c1 : i33
+ %hi = arith.trunci %sh: i33 to i32
+ return %hi : i32
+}
+
// CHECK-LABEL: @wideMulToMulSIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
@@ -2994,6 +3009,21 @@ func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
return %hi : i32
}
+// Verify that the unsigned extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulUIExtendedWithWrongShift
+// CHECK-NOT: arith.mului_extended
+func.func @wideMulToMulUIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+ %x = arith.extui %a: i32 to i33
+ %y = arith.extui %b: i32 to i33
+ %m = arith.muli %x, %y: i33
+ %c1 = arith.constant 1: i33
+ %sh = arith.shrui %m, %c1 : i33
+ %hi = arith.trunci %sh: i33 to i32
+ return %hi : i32
+}
+
// CHECK-LABEL: @wideMulToMulUIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
@kuhar this showed up as a downstream issue. I tried to fix it to the best of my understanding, but it definitely needs a second pair of eyes.
Pull Request Overview
This PR ensures that the Arith dialect’s extended multiplication canonicalization patterns only fire when the right‐shift amount exactly equals the operand bitwidth.
- Added negative MLIR tests to verify that neither signed nor unsigned extended multiply patterns match when the shift constant is incorrect.
- Introduced a new ValueWidthMatchesShiftAmount constraint and applied it to both signed (mulsi_extended) and unsigned (mului_extended) patterns.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| mlir/test/Dialect/Arith/canonicalize.mlir | Added CHECK-NOT tests for wrong shift amounts in both signed and unsigned cases |
| mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | Defined ValueWidthMatchesShiftAmount constraint and added it to relevant patterns |
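The interaction between the existing and the new constraint can be sketched as plain predicates. This is only an illustrative model in Python, assuming the widths are passed in as integers and a non-constant shift is represented by `None`; the function names are hypothetical, and the `ValuesWithSameType` check from the patterns is omitted.

```python
def truncation_matches_shift_amount(mul_width, operand_width, shift):
    """Existing check: the bits dropped by the truncation equal the shift."""
    return shift is not None and mul_width - operand_width == shift


def value_width_matches_shift_amount(operand_width, shift):
    """New check: the shift equals the width of the original operands."""
    return shift is not None and operand_width == shift


def pattern_may_fire(mul_width, operand_width, shift):
    """Combined model of the width-related constraints after the fix."""
    return (
        mul_width > operand_width  # stands in for ValueWiderThan
        and truncation_matches_shift_amount(mul_width, operand_width, shift)
        and value_width_matches_shift_amount(operand_width, shift)
    )


# i32 operands extended to i33 and shifted by 1: accepted before, rejected now.
print(truncation_matches_shift_amount(33, 32, 1), pattern_may_fire(33, 32, 1))
# i32 operands extended to i64 and shifted by 32: still accepted.
print(pattern_may_fire(64, 32, 32))
```

Taken together, the two checks only hold when the multiplication is exactly twice as wide as the operands, which is the case in which the shifted and truncated value is the high word.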
Comments suppressed due to low confidence (3)
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td:379
- [nitpick] Add a brief comment above this constraint explaining that it requires the shift constant to equal the operand's bitwidth; this will help future maintainers understand its intent.
def ValueWidthMatchesShiftAmount :
mlir/test/Dialect/Arith/canonicalize.mlir:2972
- [nitpick] Consider adding a positive test case for the scenario where the shift constant equals the operand bitwidth to verify that the canonicalization still fires as expected.
// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td:379
- [nitpick] The name ValueWidthMatchesShiftAmount is quite long; consider renaming it to something like ShiftEqualsValueWidth for improved readability.
def ValueWidthMatchesShiftAmount :
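For the positive case raised in the review notes, a quick numeric check suggests the rewrite is sound when the operands are extended to twice their width and the shift equals the operand width. The sketch below is a Python model under those assumptions; all helper names are hypothetical and it does not exercise the MLIR canonicalizer itself.

```python
def to_signed(value, width):
    """Interpret the low `width` bits of `value` as a two's complement integer."""
    value &= (1 << width) - 1
    return value - (1 << width) if value >> (width - 1) else value


def extend(value, signed, width):
    """Model extsi (signed) or extui (unsigned) as a Python integer value."""
    return to_signed(value, width) if signed else value & ((1 << width) - 1)


def high_via_wide_mul(a, b, signed, width=32):
    """ext -> muli at 2*width -> shrui by width -> trunci to width."""
    wide_mask = (1 << (2 * width)) - 1
    m = (extend(a, signed, width) * extend(b, signed, width)) & wide_mask
    return (m >> width) & ((1 << width) - 1)


def high_via_extended_mul(a, b, signed, width=32):
    """High word of the exact product, as an extended multiplication returns it."""
    product = extend(a, signed, width) * extend(b, signed, width)
    # Python's >> on negative integers is an arithmetic (floor) shift.
    return (product >> width) & ((1 << width) - 1)


edge_values = [0, 1, 2, 3, 0x7FFFFFFF, 0x80000000, 0xFFFFFFFE, 0xFFFFFFFF]
agree = all(
    high_via_wide_mul(a, b, s) == high_via_extended_mul(a, b, s)
    for a in edge_values
    for b in edge_values
    for s in (False, True)
)
print(agree)
```

The check covers only a handful of edge values, so it is a sanity check rather than a proof.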
Thanks for fixing this. I think we should drop these canon patterns altogether -- the issue is that one of the results is always unused, so it's probably too aggressive to be enabled by default without a clear benefit. Similarly, it doesn't allow us to decompose muli_extended to plain muli without it being folded back again.
I'm not aware of any code that relies on these patterns anymore.
The Arith dialect includes patterns that canonicalize a sequence of:

- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication, which assumes that the shift amount is equal to the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount was less than the bit width.

For example, the following code:

```
%x = arith.extui %a: i32 to i33
%y = arith.extui %b: i32 to i33
%m = arith.muli %x, %y: i33
%c1 = arith.constant 1: i33
%sh = arith.shrui %m, %c1 : i33
%hi = arith.trunci %sh: i33 to i32
```

would incorrectly be canonicalized to:

```
_, %hi = arith.mului_extended %a, %b : i32
```

This commit removes the faulty canonicalizations since they are not believed to be generally beneficial (cf. the discussion of the alternative llvm#144787, which fixes the canonicalizations).
Ok, closing this one in favor of #144844.
Why is it a problem? Seems like the cost-model thing is a lowering problem.
That is completely expected: these kinds of decompositions are "lowering" (or "codegen prepare") and as such aren't expected to be resilient to canonicalization.
I do not have a strong opinion on keeping vs removing the canonicalization (as long as the bug is fixed). It is true that the output of the patterns seems more canonical from an arith dialect perspective.
This pattern is kind of obscure while adding logical complexity / maintenance cost to the codebase. I remember it appearing exactly once in an internal model and I'm not aware of any other real-world input where it actually matches. I think this kind of fold is best done by each backend when we know whether the target natively calculates the high result of multiplication or not.
Ack, I agree with you on this point after thinking about this more. But overall, I don't think this pattern is worth keeping. If someone actually finds a use case for it, they should be able to re-add it without much effort.