Skip to content

Commit 345f57d

Browse files
authored
[mlir][arith] Overflow flags propagation in arith canonicalizations. (#91646)
1 parent efe91cf commit 345f57d

File tree

3 files changed

+145
-17
lines changed

3 files changed

+145
-17
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
2424
// Multiply two integer attributes and create a new one with the result.
2525
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
2626

27-
// TODO: Canonicalizations currently doesn't take into account integer overflow
28-
// flags and always reset them to default (wraparound) which is safe but can
29-
// inhibit later optimizations. Individual patterns must be reviewed for
30-
// better handling of overflow flags.
27+
// Merge overflow flags from 2 ops, conservatively keeping flags set on both.
28+
def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
29+
30+
// Default overflow flag (all wraparounds allowed).
3131
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
3232

3333
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
@@ -45,23 +45,23 @@ def AddIAddConstant :
4545
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
4646
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
4747
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
48-
DefOverflow)>;
48+
(MergeOverflow $ovf1, $ovf2))>;
4949

5050
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
5151
def AddISubConstantRHS :
5252
Pat<(Arith_AddIOp:$res
5353
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
5454
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
5555
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
56-
DefOverflow)>;
56+
(MergeOverflow $ovf1, $ovf2))>;
5757

5858
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
5959
def AddISubConstantLHS :
6060
Pat<(Arith_AddIOp:$res
6161
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
6262
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
6363
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
64-
DefOverflow)>;
64+
(MergeOverflow $ovf1, $ovf2))>;
6565

6666
def IsScalarOrSplatNegativeOne :
6767
Constraint<And<[
@@ -73,15 +73,15 @@ def AddIMulNegativeOneRhs :
7373
Pat<(Arith_AddIOp
7474
$x,
7575
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
76-
(Arith_SubIOp $x, $y, DefOverflow),
76+
(Arith_SubIOp $x, $y, DefOverflow), // TODO: propagate overflow flags ($ovf1, $ovf2)
7777
[(IsScalarOrSplatNegativeOne $c0)]>;
7878

7979
// addi(muli(x, -1), y) -> subi(y, x)
8080
def AddIMulNegativeOneLhs :
8181
Pat<(Arith_AddIOp
8282
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
8383
$y, $ovf2),
84-
(Arith_SubIOp $y, $x, DefOverflow),
84+
(Arith_SubIOp $y, $x, DefOverflow), // TODO: propagate overflow flags ($ovf1, $ovf2)
8585
[(IsScalarOrSplatNegativeOne $c0)]>;
8686

8787
// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
9090
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
9191
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
9292
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
93-
DefOverflow)>;
93+
(MergeOverflow $ovf1, $ovf2))>;
9494

9595
//===----------------------------------------------------------------------===//
9696
// AddUIExtendedOp
@@ -113,52 +113,53 @@ def SubIRHSAddConstant :
113113
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
114114
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
115115
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
116-
DefOverflow)>;
116+
DefOverflow)>; // TODO: propagate overflow flags ($ovf1, $ovf2)
117117

118118
// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
119119
def SubILHSAddConstant :
120120
Pat<(Arith_SubIOp:$res
121121
(ConstantLikeMatcher APIntAttr:$c1),
122122
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
123123
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
124-
DefOverflow)>;
124+
(MergeOverflow $ovf1, $ovf2))>;
125125

126126
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
127127
def SubIRHSSubConstantRHS :
128128
Pat<(Arith_SubIOp:$res
129129
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
130130
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
131131
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
132-
DefOverflow)>;
132+
(MergeOverflow $ovf1, $ovf2))>;
133133

134134
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
135135
def SubIRHSSubConstantLHS :
136136
Pat<(Arith_SubIOp:$res
137137
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
138138
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
139139
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
140-
DefOverflow)>;
140+
(MergeOverflow $ovf1, $ovf2))>;
141141

142142
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
143143
def SubILHSSubConstantRHS :
144144
Pat<(Arith_SubIOp:$res
145145
(ConstantLikeMatcher APIntAttr:$c1),
146146
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
147147
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
148-
DefOverflow)>;
148+
(MergeOverflow $ovf1, $ovf2))>;
149149

150150
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
151151
def SubILHSSubConstantLHS :
152152
Pat<(Arith_SubIOp:$res
153153
(ConstantLikeMatcher APIntAttr:$c1),
154154
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
155155
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
156-
DefOverflow)>;
156+
(MergeOverflow $ovf1, $ovf2))>;
157157

158158
// subi(subi(a, b), a) -> subi(0, b)
159159
def SubISubILHSRHSLHS :
160160
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
161-
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
161+
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y,
162+
(MergeOverflow $ovf1, $ovf2))>;
162163

163164
//===----------------------------------------------------------------------===//
164165
// MulSIExtendedOp

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
6464
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
6565
}
6666

67+
// Merge overflow flags from 2 ops, conservatively keeping flags set on both.
68+
static IntegerOverflowFlagsAttr
69+
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70+
IntegerOverflowFlagsAttr val2) {
71+
return IntegerOverflowFlagsAttr::get(val1.getContext(),
72+
val1.getValue() & val2.getValue());
73+
}
74+
6775
/// Invert an integer comparison predicate.
6876
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
6977
switch (pred) {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,30 @@ func.func @tripleAddAdd(%arg0: index) -> index {
833833
return %add2 : index
834834
}
835835

836+
// CHECK-LABEL: @tripleAddAddOvf1
837+
// CHECK: %[[cres:.+]] = arith.constant 59 : index
838+
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
839+
// CHECK: return %[[add]]
840+
func.func @tripleAddAddOvf1(%arg0: index) -> index {
841+
%c17 = arith.constant 17 : index
842+
%c42 = arith.constant 42 : index
843+
%add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
844+
%add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
845+
return %add2 : index
846+
}
847+
848+
// CHECK-LABEL: @tripleAddAddOvf2
849+
// CHECK: %[[cres:.+]] = arith.constant 59 : index
850+
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
851+
// CHECK: return %[[add]]
852+
func.func @tripleAddAddOvf2(%arg0: index) -> index {
853+
%c17 = arith.constant 17 : index
854+
%c42 = arith.constant 42 : index
855+
%add1 = arith.addi %c17, %arg0 overflow<nsw> : index
856+
%add2 = arith.addi %c42, %add1 overflow<nuw> : index
857+
return %add2 : index
858+
}
859+
836860
// CHECK-LABEL: @tripleAddSub0
837861
// CHECK: %[[cres:.+]] = arith.constant 59 : index
838862
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -845,6 +869,18 @@ func.func @tripleAddSub0(%arg0: index) -> index {
845869
return %add2 : index
846870
}
847871

872+
// CHECK-LABEL: @tripleAddSub0Ovf
873+
// CHECK: %[[cres:.+]] = arith.constant 59 : index
874+
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
875+
// CHECK: return %[[add]]
876+
func.func @tripleAddSub0Ovf(%arg0: index) -> index {
877+
%c17 = arith.constant 17 : index
878+
%c42 = arith.constant 42 : index
879+
%add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
880+
%add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
881+
return %add2 : index
882+
}
883+
848884
// CHECK-LABEL: @tripleAddSub1
849885
// CHECK: %[[cres:.+]] = arith.constant 25 : index
850886
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -857,6 +893,18 @@ func.func @tripleAddSub1(%arg0: index) -> index {
857893
return %add2 : index
858894
}
859895

896+
// CHECK-LABEL: @tripleAddSub1Ovf
897+
// CHECK: %[[cres:.+]] = arith.constant 25 : index
898+
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
899+
// CHECK: return %[[add]]
900+
func.func @tripleAddSub1Ovf(%arg0: index) -> index {
901+
%c17 = arith.constant 17 : index
902+
%c42 = arith.constant 42 : index
903+
%add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
904+
%add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
905+
return %add2 : index
906+
}
907+
860908
// CHECK-LABEL: @tripleSubAdd0
861909
// CHECK: %[[cres:.+]] = arith.constant 25 : index
862910
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -869,6 +917,18 @@ func.func @tripleSubAdd0(%arg0: index) -> index {
869917
return %add2 : index
870918
}
871919

920+
// CHECK-LABEL: @tripleSubAdd0Ovf
921+
// CHECK: %[[cres:.+]] = arith.constant 25 : index
922+
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
923+
// CHECK: return %[[add]]
924+
func.func @tripleSubAdd0Ovf(%arg0: index) -> index {
925+
%c17 = arith.constant 17 : index
926+
%c42 = arith.constant 42 : index
927+
%add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
928+
%add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
929+
return %add2 : index
930+
}
931+
872932
// CHECK-LABEL: @tripleSubAdd1
873933
// CHECK: %[[cres:.+]] = arith.constant -25 : index
874934
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -891,6 +951,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index {
891951
return %sub2 : index
892952
}
893953

954+
// CHECK-LABEL: @subSub0Ovf
955+
// CHECK: %[[c0:.+]] = arith.constant 0 : index
956+
// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow<nsw, nuw> : index
957+
// CHECK: return %[[add]]
958+
func.func @subSub0Ovf(%arg0: index, %arg1: index) -> index {
959+
%sub1 = arith.subi %arg0, %arg1 overflow<nsw, nuw> : index
960+
%sub2 = arith.subi %sub1, %arg0 overflow<nsw, nuw> : index
961+
return %sub2 : index
962+
}
963+
894964
// CHECK-LABEL: @tripleSubSub0
895965
// CHECK: %[[cres:.+]] = arith.constant 25 : index
896966
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -903,6 +973,19 @@ func.func @tripleSubSub0(%arg0: index) -> index {
903973
return %add2 : index
904974
}
905975

976+
// CHECK-LABEL: @tripleSubSub0Ovf
977+
// CHECK: %[[cres:.+]] = arith.constant 25 : index
978+
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
979+
// CHECK: return %[[add]]
980+
func.func @tripleSubSub0Ovf(%arg0: index) -> index {
981+
%c17 = arith.constant 17 : index
982+
%c42 = arith.constant 42 : index
983+
%add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
984+
%add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
985+
return %add2 : index
986+
}
987+
988+
906989
// CHECK-LABEL: @tripleSubSub1
907990
// CHECK: %[[cres:.+]] = arith.constant -25 : index
908991
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -915,6 +998,18 @@ func.func @tripleSubSub1(%arg0: index) -> index {
915998
return %add2 : index
916999
}
9171000

1001+
// CHECK-LABEL: @tripleSubSub1Ovf
1002+
// CHECK: %[[cres:.+]] = arith.constant -25 : index
1003+
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
1004+
// CHECK: return %[[add]]
1005+
func.func @tripleSubSub1Ovf(%arg0: index) -> index {
1006+
%c17 = arith.constant 17 : index
1007+
%c42 = arith.constant 42 : index
1008+
%add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
1009+
%add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
1010+
return %add2 : index
1011+
}
1012+
9181013
// CHECK-LABEL: @tripleSubSub2
9191014
// CHECK: %[[cres:.+]] = arith.constant 59 : index
9201015
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -927,6 +1022,18 @@ func.func @tripleSubSub2(%arg0: index) -> index {
9271022
return %add2 : index
9281023
}
9291024

1025+
// CHECK-LABEL: @tripleSubSub2Ovf
1026+
// CHECK: %[[cres:.+]] = arith.constant 59 : index
1027+
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
1028+
// CHECK: return %[[add]]
1029+
func.func @tripleSubSub2Ovf(%arg0: index) -> index {
1030+
%c17 = arith.constant 17 : index
1031+
%c42 = arith.constant 42 : index
1032+
%add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
1033+
%add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
1034+
return %add2 : index
1035+
}
1036+
9301037
// CHECK-LABEL: @tripleSubSub3
9311038
// CHECK: %[[cres:.+]] = arith.constant 59 : index
9321039
// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] : index
@@ -939,6 +1046,18 @@ func.func @tripleSubSub3(%arg0: index) -> index {
9391046
return %add2 : index
9401047
}
9411048

1049+
// CHECK-LABEL: @tripleSubSub3Ovf
1050+
// CHECK: %[[cres:.+]] = arith.constant 59 : index
1051+
// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] overflow<nsw, nuw> : index
1052+
// CHECK: return %[[add]]
1053+
func.func @tripleSubSub3Ovf(%arg0: index) -> index {
1054+
%c17 = arith.constant 17 : index
1055+
%c42 = arith.constant 42 : index
1056+
%add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
1057+
%add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
1058+
return %add2 : index
1059+
}
1060+
9421061
// CHECK-LABEL: @subAdd1
9431062
// CHECK-NEXT: return %arg0
9441063
func.func @subAdd1(%arg0: index, %arg1 : index) -> index {

0 commit comments

Comments
 (0)