Skip to content

Commit 7410f33

Browse files
sabaumaSpenser Bauman
authored andcommitted
Rework integer range analysis interfaces
Modify the integer range analysis interfaces to handle uninitialized values by allowing the inferred input ranges to be optional.
1 parent 377db1a commit 7410f33

File tree

10 files changed

+366
-205
lines changed

10 files changed

+366
-205
lines changed

mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class IntegerValueRange {
3333
static IntegerValueRange getMaxRange(Value value);
3434

3535
/// Create an integer value range lattice value.
36-
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
36+
IntegerValueRange(OptionalIntRanges value = std::nullopt)
3737
: value(std::move(value)) {}
3838

3939
/// Whether the range is uninitialized. This happens when the state hasn't

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,11 @@ class ConstantIntRanges {
105105

106106
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
107107

108+
using OptionalIntRanges = std::optional<ConstantIntRanges>;
108109
/// The type of the `setResultRanges` callback provided to ops implementing
109110
/// InferIntRangeInterface. It should be called once for each integer result
110111
/// value and be passed the ConstantIntRanges corresponding to that value.
111-
using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
112+
using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
112113
} // end namespace mlir
113114

114115
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"

mlir/include/mlir/Interfaces/InferIntRangeInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
4545
APInts in their `argRanges` element.
4646
}],
4747
"void", "inferResultRanges", (ins
48-
"::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
48+
"::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
4949
"::mlir::SetIntRangeFn":$setResultRanges)
5050
>];
5151
}

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ namespace intrange {
2525
/// abstracted away here to permit writing the function that handles both
2626
/// 64- and 32-bit index types.
2727
using InferRangeFn =
28-
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
28+
std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
29+
30+
using OptionalRangeFn =
31+
std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
2932

3033
static constexpr unsigned indexMinWidth = 32;
3134
static constexpr unsigned indexMaxWidth = 64;
@@ -44,6 +47,8 @@ enum class OverflowFlags : uint32_t {
4447
using InferRangeWithOvfFlagsFn =
4548
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
4649

50+
OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
51+
4752
/// Compute `inferFn` on `ranges`, whose size should be the index storage
4853
/// bitwidth. Then, compute the function on `argRanges` again after truncating
4954
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,26 @@
3636
using namespace mlir;
3737
using namespace mlir::dataflow;
3838

39+
namespace {
40+
41+
OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
42+
if (range.isUninitialized())
43+
return std::nullopt;
44+
return range.getValue();
45+
}
46+
47+
OptionalIntRanges
48+
getOptionalRangeFromLattice(const IntegerValueRangeLattice *lattice) {
49+
return getOptionalRange(lattice->getValue());
50+
}
51+
52+
} // end namespace
53+
3954
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
4055
unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
56+
if (width == 0)
57+
return {};
58+
4159
APInt umin = APInt::getMinValue(width);
4260
APInt umax = APInt::getMaxValue(width);
4361
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
@@ -71,23 +89,14 @@ void IntegerRangeAnalysis::visitOperation(
7189
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
7290
ArrayRef<IntegerValueRangeLattice *> results) {
7391
// If the lattice on any operand is unitialized, bail out.
74-
if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
75-
return lattice->getValue().isUninitialized();
76-
})) {
77-
return;
78-
}
79-
8092
auto inferrable = dyn_cast<InferIntRangeInterface>(op);
8193
if (!inferrable)
8294
return setAllToEntryStates(results);
8395

8496
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
85-
SmallVector<ConstantIntRanges> argRanges(
86-
llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
87-
return val->getValue().getValue();
88-
}));
97+
auto argRanges = llvm::map_to_vector(operands, getOptionalRangeFromLattice);
8998

90-
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
99+
auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
91100
auto result = dyn_cast<OpResult>(v);
92101
if (!result)
93102
return;
@@ -97,7 +106,9 @@ void IntegerRangeAnalysis::visitOperation(
97106
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
98107
IntegerValueRange oldRange = lattice->getValue();
99108

100-
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
109+
ChangeResult changed =
110+
attrs ? lattice->join(IntegerValueRange{attrs})
111+
: lattice->join(IntegerValueRange::getMaxRange(v));
101112

102113
// Catch loop results with loop variant bounds and conservatively make
103114
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -127,12 +138,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
127138
return getLatticeElementFor(op, value)->getValue().isUninitialized();
128139
}))
129140
return;
130-
SmallVector<ConstantIntRanges> argRanges(
141+
SmallVector<OptionalIntRanges> argRanges(
131142
llvm::map_range(op->getOperands(), [&](Value value) {
132-
return getLatticeElementFor(op, value)->getValue().getValue();
143+
return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
133144
}));
134145

135-
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
146+
auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
136147
auto arg = dyn_cast<BlockArgument>(v);
137148
if (!arg)
138149
return;
@@ -143,7 +154,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
143154
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
144155
IntegerValueRange oldRange = lattice->getValue();
145156

146-
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
157+
ChangeResult changed =
158+
attrs ? lattice->join(IntegerValueRange{attrs})
159+
: lattice->join(IntegerValueRange::getMaxRange(v));
147160

148161
// Catch loop results with loop variant bounds and conservatively make
149162
// them [-inf, inf] so we don't circle around infinitely often (because

0 commit comments

Comments
 (0)