Skip to content

Commit 35f4cdb

Browse files
[mlir][arith] Add constraints to the MulIOp for preventing type mismatch while folding (#136093)
Fixes #135289 The original version didn't check if the types of lhs, rhs, and the result matched, which could cause type errors. This fix adds type checks to make sure the constants attributes have the same type as the SSA values before applying the simplification.
1 parent c3c0b27 commit 35f4cdb

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def MulIMulIConstant :
9090
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
9191
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
9292
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
93-
(MergeOverflow $ovf1, $ovf2))>;
93+
(MergeOverflow $ovf1, $ovf2)),
94+
[(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
95+
(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
9496

9597
//===----------------------------------------------------------------------===//
9698
// AddUIExtendedOp

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,24 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
12341234
return %add : index
12351235
}
12361236

1237+
// Negative test case to ensure no further folding is performed when there's a type mismatch between the values and the result.
1238+
// CHECK-LABEL: func.func @nested_muli() -> i32 {
1239+
// CHECK: %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
1240+
// CHECK: %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
1241+
// CHECK: %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
1242+
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : i32
1243+
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
1244+
// CHECK: return %[[VAL_4]] : i32
1245+
// CHECK: }
1246+
func.func @nested_muli() -> (i32) {
1247+
%0 = "test.constant"() {value = 0x7fffffff} : () -> i32
1248+
%1 = "test.constant"() {value = -2147483648} : () -> i32
1249+
%2 = "test.constant"() {value = 0x80000000} : () -> i32
1250+
%4 = arith.muli %0, %1 : i32
1251+
%5 = arith.muli %4, %2 : i32
1252+
return %5 : i32
1253+
}
1254+
12371255
// CHECK-LABEL: @tripleMulIMulIIndex
12381256
// CHECK: %[[cres:.+]] = arith.constant 15 : index
12391257
// CHECK: %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index

0 commit comments

Comments
 (0)