Skip to content

Commit 3092534

Browse files
committed
Fix inference and formatting
1 parent c60ef79 commit 3092534

File tree

1 file changed

+36
-35
lines changed

1 file changed

+36
-35
lines changed

mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -182,20 +182,18 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
182182
OverflowFlags ovfFlags) {
183183
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
184184

185-
bool saturateUnsigned = any(ovfFlags & OverflowFlags::Nuw);
186-
bool saturateSigned = any(ovfFlags & OverflowFlags::Nsw);
187-
ConstArithFn uadd = [saturateUnsigned](const APInt &a,
188-
const APInt &b) -> std::optional<APInt> {
185+
std::function uadd = [=](const APInt &a,
186+
const APInt &b) -> std::optional<APInt> {
189187
bool overflowed = false;
190-
APInt result = saturateUnsigned
188+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
191189
? a.uadd_sat(b)
192190
: a.uadd_ov(b, overflowed);
193191
return overflowed ? std::optional<APInt>() : result;
194192
};
195-
ConstArithFn sadd = [saturateSigned](const APInt &a,
196-
const APInt &b) -> std::optional<APInt> {
193+
std::function sadd = [=](const APInt &a,
194+
const APInt &b) -> std::optional<APInt> {
197195
bool overflowed = false;
198-
APInt result = saturateSigned
196+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
199197
? a.sadd_sat(b)
200198
: a.sadd_ov(b, overflowed);
201199
return overflowed ? std::optional<APInt>() : result;
@@ -217,19 +215,20 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
217215
OverflowFlags ovfFlags) {
218216
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
219217

220-
bool saturateUnsigned = any(ovfFlags & OverflowFlags::Nuw);
221-
bool saturateSigned = any(ovfFlags & OverflowFlags::Nsw);
222-
ConstArithFn usub =
223-
[saturateUnsigned](const APInt &a,
224-
const APInt &b) -> std::optional<APInt> {
218+
std::function usub = [=](const APInt &a,
219+
const APInt &b) -> std::optional<APInt> {
225220
bool overflowed = false;
226-
APInt result = saturateUnsigned ? a.usub_sat(b) : a.usub_ov(b, overflowed);
221+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
222+
? a.usub_sat(b)
223+
: a.usub_ov(b, overflowed);
227224
return overflowed ? std::optional<APInt>() : result;
228225
};
229-
ConstArithFn ssub = [saturateSigned](const APInt &a,
230-
const APInt &b) -> std::optional<APInt> {
226+
std::function ssub = [=](const APInt &a,
227+
const APInt &b) -> std::optional<APInt> {
231228
bool overflowed = false;
232-
APInt result = saturateSigned ? a.ssub_sat(b) : a.ssub_ov(b, overflowed);
229+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
230+
? a.ssub_sat(b)
231+
: a.ssub_ov(b, overflowed);
233232
return overflowed ? std::optional<APInt>() : result;
234233
};
235234
ConstantIntRanges urange = computeBoundsBy(
@@ -248,19 +247,20 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
248247
OverflowFlags ovfFlags) {
249248
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
250249

251-
bool saturateUnsigned = any(ovfFlags & OverflowFlags::Nuw);
252-
bool saturateSigned = any(ovfFlags & OverflowFlags::Nsw);
253-
ConstArithFn umul =
254-
[saturateUnsigned](const APInt &a,
255-
const APInt &b) -> std::optional<APInt> {
250+
std::function umul = [=](const APInt &a,
251+
const APInt &b) -> std::optional<APInt> {
256252
bool overflowed = false;
257-
APInt result = saturateUnsigned ? a.umul_sat(b) : a.umul_ov(b, overflowed);
253+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
254+
? a.umul_sat(b)
255+
: a.umul_ov(b, overflowed);
258256
return overflowed ? std::optional<APInt>() : result;
259257
};
260-
ConstArithFn smul = [saturateSigned](const APInt &a,
261-
const APInt &b) -> std::optional<APInt> {
258+
std::function smul = [=](const APInt &a,
259+
const APInt &b) -> std::optional<APInt> {
262260
bool overflowed = false;
263-
APInt result = saturateSigned ? a.smul_sat(b) : a.smul_ov(b, overflowed);
261+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
262+
? a.smul_sat(b)
263+
: a.smul_ov(b, overflowed);
264264
return overflowed ? std::optional<APInt>() : result;
265265
};
266266

@@ -565,19 +565,20 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
565565

566566
// The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
567567
// 2^rhs.
568-
bool saturateUnsigned = any(ovfFlags & OverflowFlags::Nuw);
569-
bool saturateSigned = any(ovfFlags & OverflowFlags::Nsw);
570-
ConstArithFn ushl =
571-
[saturateUnsigned](const APInt &l,
572-
const APInt &r) -> std::optional<APInt> {
568+
std::function ushl = [=](const APInt &l,
569+
const APInt &r) -> std::optional<APInt> {
573570
bool overflowed = false;
574-
APInt result = saturateUnsigned ? l.ushl_sat(r) : l.ushl_ov(r, overflowed);
571+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
572+
? l.ushl_sat(r)
573+
: l.ushl_ov(r, overflowed);
575574
return overflowed ? std::optional<APInt>() : result;
576575
};
577-
ConstArithFn sshl = [saturateSigned](const APInt &l,
578-
const APInt &r) -> std::optional<APInt> {
576+
std::function sshl = [=](const APInt &l,
577+
const APInt &r) -> std::optional<APInt> {
579578
bool overflowed = false;
580-
APInt result = saturateSigned ? l.sshl_sat(r) : l.sshl_ov(r, overflowed);
579+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
580+
? l.sshl_sat(r)
581+
: l.sshl_ov(r, overflowed);
581582
return overflowed ? std::optional<APInt>() : result;
582583
};
583584

0 commit comments

Comments
 (0)