Skip to content

Commit bdf8ed0

Browse files
committed
[mlir][arith] Fold (a * b) / b
Alive2 check: https://alive2.llvm.org/ce/z/5XWjWE
1 parent c1ecc0d commit bdf8ed0

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,29 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
580580
// DivUIOp
581581
//===----------------------------------------------------------------------===//
582582

583+
static Value foldDivMul(Value lhs, Value rhs,
584+
arith::IntegerOverflowFlags ovfFlags) {
585+
auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
586+
if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
587+
return {};
588+
589+
if (mul.getLhs() == rhs)
590+
return mul.getRhs();
591+
592+
if (mul.getRhs() == rhs)
593+
return mul.getLhs();
594+
595+
return {};
596+
}
597+
583598
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
584599
// divui (x, 1) -> x.
585600
if (matchPattern(adaptor.getRhs(), m_One()))
586601
return getLhs();
587602

603+
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
604+
return val;
605+
588606
// Don't fold if it would require a division by zero.
589607
bool div0 = false;
590608
auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@@ -621,6 +639,9 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
621639
if (matchPattern(adaptor.getRhs(), m_One()))
622640
return getLhs();
623641

642+
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
643+
return val;
644+
624645
// Don't fold if it would overflow or if it requires a division by zero.
625646
bool overflowOrDiv0 = false;
626647
auto result = constFoldBinaryOp<IntegerAttr>(

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,26 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
20602060

20612061
// -----
20622062

2063+
// CHECK-LABEL: @test_divui_mul
2064+
// CHECK-SAME: (%[[ARG:.*]]: index, %{{.*}}: index)
2065+
// CHECK: return %[[ARG]]
2066+
func.func @test_divui_mul(%arg0: index, %arg1: index) -> index {
2067+
%0 = arith.muli %arg0, %arg1 overflow<nuw> : index
2068+
%1 = arith.divui %0, %arg1 : index
2069+
return %1 : index
2070+
}
2071+
2072+
// CHECK-LABEL: @test_divsi_mul
2073+
// CHECK-SAME: (%[[ARG:.*]]: index, %{{.*}}: index)
2074+
// CHECK: return %[[ARG]]
2075+
func.func @test_divsi_mul(%arg0: index, %arg1: index) -> index {
2076+
%0 = arith.muli %arg1, %arg0 overflow<nsw> : index
2077+
%1 = arith.divsi %0, %arg1 : index
2078+
return %1 : index
2079+
}
2080+
2081+
// -----
2082+
20632083
// CHECK-LABEL: @test_cmpf(
20642084
func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
20652085
// CHECK-DAG: %[[T:.*]] = arith.constant true

0 commit comments

Comments
 (0)