@@ -181,18 +181,21 @@ ConstantIntRanges
181
181
mlir::intrange::inferAdd (ArrayRef<ConstantIntRanges> argRanges,
182
182
OverflowFlags ovfFlags) {
183
183
const ConstantIntRanges &lhs = argRanges[0 ], &rhs = argRanges[1 ];
184
- ConstArithFn uadd = [ovfFlags](const APInt &a,
184
+
185
+ bool saturateUnsigned = any (ovfFlags & OverflowFlags::Nuw);
186
+ bool saturateSigned = any (ovfFlags & OverflowFlags::Nsw);
187
+ ConstArithFn uadd = [saturateUnsigned](const APInt &a,
185
188
const APInt &b) -> std::optional<APInt> {
186
189
bool overflowed = false ;
187
- APInt result = any (ovfFlags & OverflowFlags::Nuw)
190
+ APInt result = saturateUnsigned
188
191
? a.uadd_sat (b)
189
192
: a.uadd_ov (b, overflowed);
190
193
return overflowed ? std::optional<APInt>() : result;
191
194
};
192
- ConstArithFn sadd = [ovfFlags ](const APInt &a,
195
+ ConstArithFn sadd = [saturateSigned ](const APInt &a,
193
196
const APInt &b) -> std::optional<APInt> {
194
197
bool overflowed = false ;
195
- APInt result = any (ovfFlags & OverflowFlags::Nsw)
198
+ APInt result = saturateSigned
196
199
? a.sadd_sat (b)
197
200
: a.sadd_ov (b, overflowed);
198
201
return overflowed ? std::optional<APInt>() : result;
@@ -214,20 +217,19 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
214
217
OverflowFlags ovfFlags) {
215
218
const ConstantIntRanges &lhs = argRanges[0 ], &rhs = argRanges[1 ];
216
219
217
- ConstArithFn usub = [ovfFlags](const APInt &a,
218
- const APInt &b) -> std::optional<APInt> {
220
+ bool saturateUnsigned = any (ovfFlags & OverflowFlags::Nuw);
221
+ bool saturateSigned = any (ovfFlags & OverflowFlags::Nsw);
222
+ ConstArithFn usub =
223
+ [saturateUnsigned](const APInt &a,
224
+ const APInt &b) -> std::optional<APInt> {
219
225
bool overflowed = false ;
220
- APInt result = any (ovfFlags & OverflowFlags::Nuw)
221
- ? a.usub_sat (b)
222
- : a.usub_ov (b, overflowed);
226
+ APInt result = saturateUnsigned ? a.usub_sat (b) : a.usub_ov (b, overflowed);
223
227
return overflowed ? std::optional<APInt>() : result;
224
228
};
225
- ConstArithFn ssub = [ovfFlags ](const APInt &a,
226
- const APInt &b) -> std::optional<APInt> {
229
+ ConstArithFn ssub = [saturateSigned ](const APInt &a,
230
+ const APInt &b) -> std::optional<APInt> {
227
231
bool overflowed = false ;
228
- APInt result = any (ovfFlags & OverflowFlags::Nsw)
229
- ? a.ssub_sat (b)
230
- : a.ssub_ov (b, overflowed);
232
+ APInt result = saturateSigned ? a.ssub_sat (b) : a.ssub_ov (b, overflowed);
231
233
return overflowed ? std::optional<APInt>() : result;
232
234
};
233
235
ConstantIntRanges urange = computeBoundsBy (
@@ -246,20 +248,19 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
246
248
OverflowFlags ovfFlags) {
247
249
const ConstantIntRanges &lhs = argRanges[0 ], &rhs = argRanges[1 ];
248
250
249
- ConstArithFn umul = [ovfFlags](const APInt &a,
250
- const APInt &b) -> std::optional<APInt> {
251
+ bool saturateUnsigned = any (ovfFlags & OverflowFlags::Nuw);
252
+ bool saturateSigned = any (ovfFlags & OverflowFlags::Nsw);
253
+ ConstArithFn umul =
254
+ [saturateUnsigned](const APInt &a,
255
+ const APInt &b) -> std::optional<APInt> {
251
256
bool overflowed = false ;
252
- APInt result = any (ovfFlags & OverflowFlags::Nuw)
253
- ? a.umul_sat (b)
254
- : a.umul_ov (b, overflowed);
257
+ APInt result = saturateUnsigned ? a.umul_sat (b) : a.umul_ov (b, overflowed);
255
258
return overflowed ? std::optional<APInt>() : result;
256
259
};
257
- ConstArithFn smul = [ovfFlags ](const APInt &a,
258
- const APInt &b) -> std::optional<APInt> {
260
+ ConstArithFn smul = [saturateSigned ](const APInt &a,
261
+ const APInt &b) -> std::optional<APInt> {
259
262
bool overflowed = false ;
260
- APInt result = any (ovfFlags & OverflowFlags::Nsw)
261
- ? a.smul_sat (b)
262
- : a.smul_ov (b, overflowed);
263
+ APInt result = saturateSigned ? a.smul_sat (b) : a.smul_ov (b, overflowed);
263
264
return overflowed ? std::optional<APInt>() : result;
264
265
};
265
266
@@ -564,20 +565,19 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
564
565
565
566
// The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
566
567
// 2^rhs.
567
- ConstArithFn ushl = [ovfFlags](const APInt &l,
568
- const APInt &r) -> std::optional<APInt> {
568
+ bool saturateUnsigned = any (ovfFlags & OverflowFlags::Nuw);
569
+ bool saturateSigned = any (ovfFlags & OverflowFlags::Nsw);
570
+ ConstArithFn ushl =
571
+ [saturateUnsigned](const APInt &l,
572
+ const APInt &r) -> std::optional<APInt> {
569
573
bool overflowed = false ;
570
- APInt result = any (ovfFlags & OverflowFlags::Nuw)
571
- ? l.ushl_sat (r)
572
- : l.ushl_ov (r, overflowed);
574
+ APInt result = saturateUnsigned ? l.ushl_sat (r) : l.ushl_ov (r, overflowed);
573
575
return overflowed ? std::optional<APInt>() : result;
574
576
};
575
- ConstArithFn sshl = [ovfFlags ](const APInt &l,
576
- const APInt &r) -> std::optional<APInt> {
577
+ ConstArithFn sshl = [saturateSigned ](const APInt &l,
578
+ const APInt &r) -> std::optional<APInt> {
577
579
bool overflowed = false ;
578
- APInt result = any (ovfFlags & OverflowFlags::Nsw)
579
- ? l.sshl_sat (r)
580
- : l.sshl_ov (r, overflowed);
580
+ APInt result = saturateSigned ? l.sshl_sat (r) : l.sshl_ov (r, overflowed);
581
581
return overflowed ? std::optional<APInt>() : result;
582
582
};
583
583
0 commit comments