Skip to content

[mlir][arith] Overflow flags propagation in arith canonicalizations. #91646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;

// TODO: Canonicalizations currently doesn't take into account integer overflow
// flags and always reset them to default (wraparound) which is safe but can
// inhibit later optimizations. Individual patterns must be reviewed for
// better handling of overflow flags.
// Merge overflow flags from 2 ops, selecting the most conservative combination.
def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;

// Default overflow flag (all wraparounds allowed).
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
Expand All @@ -45,23 +45,23 @@ def AddIAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
Pat<(Arith_AddIOp:$res
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
Pat<(Arith_AddIOp:$res
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

def IsScalarOrSplatNegativeOne :
Constraint<And<[
Expand All @@ -73,15 +73,15 @@ def AddIMulNegativeOneRhs :
Pat<(Arith_AddIOp
$x,
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp $x, $y, DefOverflow),
(Arith_SubIOp $x, $y, DefOverflow), // TODO: overflow flags
[(IsScalarOrSplatNegativeOne $c0)]>;

// addi(muli(x, -1), y) -> subi(y, x)
def AddIMulNegativeOneLhs :
Pat<(Arith_AddIOp
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
$y, $ovf2),
(Arith_SubIOp $y, $x, DefOverflow),
(Arith_SubIOp $y, $x, DefOverflow), // TODO: overflow flags
[(IsScalarOrSplatNegativeOne $c0)]>;

// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
Expand All @@ -90,7 +90,7 @@ def MulIMulIConstant :
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

//===----------------------------------------------------------------------===//
// AddUIExtendedOp
Expand All @@ -113,52 +113,53 @@ def SubIRHSAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
DefOverflow)>;
DefOverflow)>; // TODO: overflow flags

// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
def SubILHSAddConstant :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
DefOverflow)>;
(MergeOverflow $ovf1, $ovf2))>;

// subi(subi(a, b), a) -> subi(0, b)
def SubISubILHSRHSLHS :
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y,
(MergeOverflow $ovf1, $ovf2))>;

//===----------------------------------------------------------------------===//
// MulSIExtendedOp
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}

// Merge overflow flags from 2 ops, selecting the most conservative combination.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
IntegerOverflowFlagsAttr val2) {
return IntegerOverflowFlagsAttr::get(val1.getContext(),
val1.getValue() & val2.getValue());
}

/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
Expand Down
119 changes: 119 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,30 @@ func.func @tripleAddAdd(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleAddAddOvf1
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleAddAddOvf1(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
%add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
return %add2 : index
}

// CHECK-LABEL: @tripleAddAddOvf2
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
// CHECK: return %[[add]]
func.func @tripleAddAddOvf2(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.addi %c17, %arg0 overflow<nsw> : index
%add2 = arith.addi %c42, %add1 overflow<nuw> : index
return %add2 : index
}

// CHECK-LABEL: @tripleAddSub0
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
Expand All @@ -845,6 +869,18 @@ func.func @tripleAddSub0(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleAddSub0Ovf
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleAddSub0Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
%add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
return %add2 : index
}

// CHECK-LABEL: @tripleAddSub1
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
Expand All @@ -857,6 +893,18 @@ func.func @tripleAddSub1(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleAddSub1Ovf
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleAddSub1Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
%add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
return %add2 : index
}

// CHECK-LABEL: @tripleSubAdd0
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
Expand All @@ -869,6 +917,18 @@ func.func @tripleSubAdd0(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleSubAdd0Ovf
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleSubAdd0Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
%add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
return %add2 : index
}

// CHECK-LABEL: @tripleSubAdd1
// CHECK: %[[cres:.+]] = arith.constant -25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
Expand All @@ -891,6 +951,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index {
return %sub2 : index
}

// CHECK-LABEL: @subSub0Ovf
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @subSub0Ovf(%arg0: index, %arg1: index) -> index {
%sub1 = arith.subi %arg0, %arg1 overflow<nsw, nuw> : index
%sub2 = arith.subi %sub1, %arg0 overflow<nsw, nuw> : index
return %sub2 : index
}

// CHECK-LABEL: @tripleSubSub0
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
Expand All @@ -903,6 +973,19 @@ func.func @tripleSubSub0(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleSubSub0Ovf
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleSubSub0Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
%add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
return %add2 : index
}


// CHECK-LABEL: @tripleSubSub1
// CHECK: %[[cres:.+]] = arith.constant -25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
Expand All @@ -915,6 +998,18 @@ func.func @tripleSubSub1(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleSubSub1Ovf
// CHECK: %[[cres:.+]] = arith.constant -25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleSubSub1Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
%add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
return %add2 : index
}

// CHECK-LABEL: @tripleSubSub2
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
Expand All @@ -927,6 +1022,18 @@ func.func @tripleSubSub2(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleSubSub2Ovf
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleSubSub2Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
%add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
return %add2 : index
}

// CHECK-LABEL: @tripleSubSub3
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] : index
Expand All @@ -939,6 +1046,18 @@ func.func @tripleSubSub3(%arg0: index) -> index {
return %add2 : index
}

// CHECK-LABEL: @tripleSubSub3Ovf
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] overflow<nsw, nuw> : index
// CHECK: return %[[add]]
func.func @tripleSubSub3Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
%add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
return %add2 : index
}

// CHECK-LABEL: @subAdd1
// CHECK-NEXT: return %arg0
func.func @subAdd1(%arg0: index, %arg1 : index) -> index {
Expand Down