Skip to content

Commit 1cade86

Browse files
authored
[mlir][arith] Fold (a * b) / b -> a (#121534)
If overflow flags allow it. Alive2 check: https://alive2.llvm.org/ce/z/5XWjWE
1 parent fa56e8b commit 1cade86

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
580580
// DivUIOp
581581
//===----------------------------------------------------------------------===//
582582

583+
/// Fold `(a * b) / b -> a`
584+
static Value foldDivMul(Value lhs, Value rhs,
585+
arith::IntegerOverflowFlags ovfFlags) {
586+
auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
587+
if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
588+
return {};
589+
590+
if (mul.getLhs() == rhs)
591+
return mul.getRhs();
592+
593+
if (mul.getRhs() == rhs)
594+
return mul.getLhs();
595+
596+
return {};
597+
}
598+
583599
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
584600
// divui (x, 1) -> x.
585601
if (matchPattern(adaptor.getRhs(), m_One()))
586602
return getLhs();
587603

604+
// (a * b) / b -> a
605+
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
606+
return val;
607+
588608
// Don't fold if it would require a division by zero.
589609
bool div0 = false;
590610
auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
621641
if (matchPattern(adaptor.getRhs(), m_One()))
622642
return getLhs();
623643

644+
// (a * b) / b -> a
645+
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
646+
return val;
647+
624648
// Don't fold if it would overflow or if it requires a division by zero.
625649
bool overflowOrDiv0 = false;
626650
auto result = constFoldBinaryOp<IntegerAttr>(

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,70 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
20602060

20612061
// -----
20622062

2063+
func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
2064+
%0 = arith.muli %arg0, %arg1 overflow<nuw> : index
2065+
%1 = arith.divui %0, %arg0 : index
2066+
return %1 : index
2067+
}
2068+
// CHECK-LABEL: func @fold_divui_of_muli_0(
2069+
// CHECK-SAME: %[[ARG0:.+]]: index,
2070+
// CHECK-SAME: %[[ARG1:.+]]: index)
2071+
// CHECK: return %[[ARG1]]
2072+
2073+
func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index {
2074+
%0 = arith.muli %arg0, %arg1 overflow<nuw> : index
2075+
%1 = arith.divui %0, %arg1 : index
2076+
return %1 : index
2077+
}
2078+
// CHECK-LABEL: func @fold_divui_of_muli_1(
2079+
// CHECK-SAME: %[[ARG0:.+]]: index,
2080+
// CHECK-SAME: %[[ARG1:.+]]: index)
2081+
// CHECK: return %[[ARG0]]
2082+
2083+
func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index {
2084+
%0 = arith.muli %arg0, %arg1 overflow<nsw> : index
2085+
%1 = arith.divsi %0, %arg0 : index
2086+
return %1 : index
2087+
}
2088+
// CHECK-LABEL: func @fold_divsi_of_muli_0(
2089+
// CHECK-SAME: %[[ARG0:.+]]: index,
2090+
// CHECK-SAME: %[[ARG1:.+]]: index)
2091+
// CHECK: return %[[ARG1]]
2092+
2093+
func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index {
2094+
%0 = arith.muli %arg0, %arg1 overflow<nsw> : index
2095+
%1 = arith.divsi %0, %arg1 : index
2096+
return %1 : index
2097+
}
2098+
// CHECK-LABEL: func @fold_divsi_of_muli_1(
2099+
// CHECK-SAME: %[[ARG0:.+]]: index,
2100+
// CHECK-SAME: %[[ARG1:.+]]: index)
2101+
// CHECK: return %[[ARG0]]
2102+
2103+
// Do not fold divui(mul(a, v), v) -> a with nuw attribute.
2104+
func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index {
2105+
%0 = arith.muli %arg0, %arg1 : index
2106+
%1 = arith.divui %0, %arg0 : index
2107+
return %1 : index
2108+
}
2109+
// CHECK-LABEL: func @no_fold_divui_of_muli
2110+
// CHECK: %[[T0:.+]] = arith.muli
2111+
// CHECK: %[[T1:.+]] = arith.divui %[[T0]],
2112+
// CHECK: return %[[T1]]
2113+
2114+
// Do not fold divsi(mul(a, v), v) -> a with nuw attribute.
2115+
func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index {
2116+
%0 = arith.muli %arg0, %arg1 : index
2117+
%1 = arith.divsi %0, %arg0 : index
2118+
return %1 : index
2119+
}
2120+
// CHECK-LABEL: func @no_fold_divsi_of_muli
2121+
// CHECK: %[[T0:.+]] = arith.muli
2122+
// CHECK: %[[T1:.+]] = arith.divsi %[[T0]],
2123+
// CHECK: return %[[T1]]
2124+
2125+
// -----
2126+
20632127
// CHECK-LABEL: @test_cmpf(
20642128
func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
20652129
// CHECK-DAG: %[[T:.*]] = arith.constant true

0 commit comments

Comments
 (0)