Skip to content

Commit b354130

Browse files
author
Spenser Bauman
committed
Convert uses of OptionalIntRange to IntegerValueRange
IntegerValueRange already exists and encodes the extact information that we want to represent with OptionalIntRange. This makes the APIs clearer than passing an std::optional everywhere.
1 parent 7410f33 commit b354130

File tree

11 files changed

+367
-334
lines changed

11 files changed

+367
-334
lines changed

mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,51 +24,6 @@
2424
namespace mlir {
2525
namespace dataflow {
2626

27-
/// This lattice value represents the integer range of an SSA value.
28-
class IntegerValueRange {
29-
public:
30-
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
31-
/// range that is used to mark the value as unable to be analyzed further,
32-
/// where `t` is the type of `value`.
33-
static IntegerValueRange getMaxRange(Value value);
34-
35-
/// Create an integer value range lattice value.
36-
IntegerValueRange(OptionalIntRanges value = std::nullopt)
37-
: value(std::move(value)) {}
38-
39-
/// Whether the range is uninitialized. This happens when the state hasn't
40-
/// been set during the analysis.
41-
bool isUninitialized() const { return !value.has_value(); }
42-
43-
/// Get the known integer value range.
44-
const ConstantIntRanges &getValue() const {
45-
assert(!isUninitialized());
46-
return *value;
47-
}
48-
49-
/// Compare two ranges.
50-
bool operator==(const IntegerValueRange &rhs) const {
51-
return value == rhs.value;
52-
}
53-
54-
/// Take the union of two ranges.
55-
static IntegerValueRange join(const IntegerValueRange &lhs,
56-
const IntegerValueRange &rhs) {
57-
if (lhs.isUninitialized())
58-
return rhs;
59-
if (rhs.isUninitialized())
60-
return lhs;
61-
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
62-
}
63-
64-
/// Print the integer value range.
65-
void print(raw_ostream &os) const { os << value; }
66-
67-
private:
68-
/// The known integer value range.
69-
std::optional<ConstantIntRanges> value;
70-
};
71-
7227
/// This lattice element represents the integer value range of an SSA value.
7328
/// When this lattice is updated, it automatically updates the constant value
7429
/// of the SSA value (if the range can be narrowed to one).

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,60 @@ class ConstantIntRanges {
105105

106106
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
107107

108-
using OptionalIntRanges = std::optional<ConstantIntRanges>;
108+
/// This lattice value represents the integer range of an SSA value.
109+
class IntegerValueRange {
110+
public:
111+
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
112+
/// range that is used to mark the value as unable to be analyzed further,
113+
/// where `t` is the type of `value`.
114+
static IntegerValueRange getMaxRange(Value value);
115+
116+
/// Create an integer value range lattice value.
117+
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
118+
119+
/// Create an integer value range lattice value.
120+
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
121+
: value(std::move(value)) {}
122+
123+
/// Whether the range is uninitialized. This happens when the state hasn't
124+
/// been set during the analysis.
125+
bool isUninitialized() const { return !value.has_value(); }
126+
127+
/// Get the known integer value range.
128+
const ConstantIntRanges &getValue() const {
129+
assert(!isUninitialized());
130+
return *value;
131+
}
132+
133+
/// Compare two ranges.
134+
bool operator==(const IntegerValueRange &rhs) const {
135+
return value == rhs.value;
136+
}
137+
138+
/// Compute the least upper bound of two ranges.
139+
static IntegerValueRange join(const IntegerValueRange &lhs,
140+
const IntegerValueRange &rhs) {
141+
if (lhs.isUninitialized())
142+
return rhs;
143+
if (rhs.isUninitialized())
144+
return lhs;
145+
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
146+
}
147+
148+
/// Print the integer value range.
149+
void print(raw_ostream &os) const { os << value; }
150+
151+
private:
152+
/// The known integer value range.
153+
std::optional<ConstantIntRanges> value;
154+
};
155+
156+
raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
157+
109158
/// The type of the `setResultRanges` callback provided to ops implementing
110159
/// InferIntRangeInterface. It should be called once for each integer result
111160
/// value and be passed the ConstantIntRanges corresponding to that value.
112-
using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
161+
using SetIntRangeFn = function_ref<void(Value, const IntegerValueRange &)>;
113162
} // end namespace mlir
114163

115164
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"

mlir/include/mlir/Interfaces/InferIntRangeInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
4545
APInts in their `argRanges` element.
4646
}],
4747
"void", "inferResultRanges", (ins
48-
"::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
48+
"::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
4949
"::mlir::SetIntRangeFn":$setResultRanges)
5050
>];
5151
}

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ namespace intrange {
2727
using InferRangeFn =
2828
std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
2929

30-
using OptionalRangeFn =
31-
std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
30+
/// Function that performs inferrence on an array of `IntegerValueRange`.
31+
using InferIntegerValueRangeFn =
32+
std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>;
3233

3334
static constexpr unsigned indexMinWidth = 32;
3435
static constexpr unsigned indexMaxWidth = 64;
@@ -47,7 +48,11 @@ enum class OverflowFlags : uint32_t {
4748
using InferRangeWithOvfFlagsFn =
4849
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
4950

50-
OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
51+
/// Perform a pointwise extension of a function operating on `ConstantIntRanges`
52+
/// to a function operating on `IntegerValueRange` such that undefined input
53+
/// ranges propagate.
54+
InferIntegerValueRangeFn
55+
inferFromIntegerValueRange(intrange::InferRangeFn inferFn);
5156

5257
/// Compute `inferFn` on `ranges`, whose size should be the index storage
5358
/// bitwidth. Then, compute the function on `argRanges` again after truncating
@@ -57,7 +62,7 @@ OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
5762
///
5863
/// The `mode` argument specifies if the unsigned, signed, or both results of
5964
/// the inference computation should be used when comparing the results.
60-
ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
65+
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn,
6166
ArrayRef<ConstantIntRanges> argRanges,
6267
CmpMode mode);
6368

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,6 @@
3636
using namespace mlir;
3737
using namespace mlir::dataflow;
3838

39-
namespace {
40-
41-
OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
42-
if (range.isUninitialized())
43-
return std::nullopt;
44-
return range.getValue();
45-
}
46-
47-
OptionalIntRanges
48-
getOptionalRangeFromLattice(const IntegerValueRangeLattice *lattice) {
49-
return getOptionalRange(lattice->getValue());
50-
}
51-
52-
} // end namespace
53-
54-
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
55-
unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
56-
if (width == 0)
57-
return {};
58-
59-
APInt umin = APInt::getMinValue(width);
60-
APInt umax = APInt::getMaxValue(width);
61-
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
62-
APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
63-
return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
64-
}
65-
6639
void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
6740
Lattice::onUpdate(solver);
6841

@@ -94,9 +67,12 @@ void IntegerRangeAnalysis::visitOperation(
9467
return setAllToEntryStates(results);
9568

9669
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
97-
auto argRanges = llvm::map_to_vector(operands, getOptionalRangeFromLattice);
70+
auto argRanges = llvm::map_to_vector(
71+
operands, [](const IntegerValueRangeLattice *lattice) {
72+
return lattice->getValue();
73+
});
9874

99-
auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
75+
auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
10076
auto result = dyn_cast<OpResult>(v);
10177
if (!result)
10278
return;
@@ -106,9 +82,7 @@ void IntegerRangeAnalysis::visitOperation(
10682
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
10783
IntegerValueRange oldRange = lattice->getValue();
10884

109-
ChangeResult changed =
110-
attrs ? lattice->join(IntegerValueRange{attrs})
111-
: lattice->join(IntegerValueRange::getMaxRange(v));
85+
ChangeResult changed = lattice->join(attrs);
11286

11387
// Catch loop results with loop variant bounds and conservatively make
11488
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -133,17 +107,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
133107
ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
134108
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
135109
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
136-
// If the lattice on any operand is unitialized, bail out.
137-
if (llvm::any_of(op->getOperands(), [&](Value value) {
138-
return getLatticeElementFor(op, value)->getValue().isUninitialized();
139-
}))
140-
return;
141-
SmallVector<OptionalIntRanges> argRanges(
142-
llvm::map_range(op->getOperands(), [&](Value value) {
143-
return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
144-
}));
145110

146-
auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
111+
auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
112+
return getLatticeElementFor(op, value)->getValue();
113+
});
114+
115+
auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
147116
auto arg = dyn_cast<BlockArgument>(v);
148117
if (!arg)
149118
return;
@@ -154,9 +123,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
154123
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
155124
IntegerValueRange oldRange = lattice->getValue();
156125

157-
ChangeResult changed =
158-
attrs ? lattice->join(IntegerValueRange{attrs})
159-
: lattice->join(IntegerValueRange::getMaxRange(v));
126+
ChangeResult changed = lattice->join(attrs);
160127

161128
// Catch loop results with loop variant bounds and conservatively make
162129
// them [-inf, inf] so we don't circle around infinitely often (because

0 commit comments

Comments
 (0)