Skip to content

Commit c17aa14

Browse files
authored
[mlir][index] Fold cmp(x, x) when x isn't a constant (#78812)
Such cases show up in the middle of optimizations passes, e.g., after some rewrites and then CSE. The current folder can fold such cases when the inputs are constant; this patch improves it to fold even if the inputs are non-constant.
1 parent b86d023 commit c17aa14

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,24 @@ static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
578578
lhsRange, ConstantIntRanges::constant(cstB));
579579
}
580580

581+
/// Return the result of `cmp(pred, x, x)`
582+
static bool compareSameArgs(IndexCmpPredicate pred) {
583+
switch (pred) {
584+
case IndexCmpPredicate::EQ:
585+
case IndexCmpPredicate::SGE:
586+
case IndexCmpPredicate::SLE:
587+
case IndexCmpPredicate::UGE:
588+
case IndexCmpPredicate::ULE:
589+
return true;
590+
case IndexCmpPredicate::NE:
591+
case IndexCmpPredicate::SGT:
592+
case IndexCmpPredicate::SLT:
593+
case IndexCmpPredicate::UGT:
594+
case IndexCmpPredicate::ULT:
595+
return false;
596+
}
597+
}
598+
581599
OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
582600
// Attempt to fold if both inputs are constant.
583601
auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
@@ -606,6 +624,10 @@ OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
606624
return BoolAttr::get(getContext(), *result64);
607625
}
608626

627+
// Fold `cmp(x, x)`
628+
if (getLhs() == getRhs())
629+
return BoolAttr::get(getContext(), compareSameArgs(getPred()));
630+
609631
return {};
610632
}
611633

mlir/test/Dialect/Index/index-canonicalize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,26 @@ func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
499499
return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1
500500
}
501501

502+
// CHECK-LABEL: @cmp_same_args
503+
func.func @cmp_same_args(%a: index) -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
504+
%0 = index.cmp eq(%a, %a)
505+
%1 = index.cmp sge(%a, %a)
506+
%2 = index.cmp sle(%a, %a)
507+
%3 = index.cmp uge(%a, %a)
508+
%4 = index.cmp ule(%a, %a)
509+
%5 = index.cmp ne(%a, %a)
510+
%6 = index.cmp sgt(%a, %a)
511+
%7 = index.cmp slt(%a, %a)
512+
%8 = index.cmp ugt(%a, %a)
513+
%9 = index.cmp ult(%a, %a)
514+
515+
// CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
516+
// CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
517+
// CHECK-NEXT: return %[[TRUE]], %[[TRUE]], %[[TRUE]], %[[TRUE]], %[[TRUE]],
518+
// CHECK-SAME: %[[FALSE]], %[[FALSE]], %[[FALSE]], %[[FALSE]], %[[FALSE]]
519+
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
520+
}
521+
502522
// CHECK-LABEL: @cmp_nofold
503523
func.func @cmp_nofold() -> i1 {
504524
%lhs = index.constant 1

0 commit comments

Comments
 (0)