Skip to content

Commit 2f5a403

Browse files
sabaumaSpenser Bauman
authored andcommitted
Improved range analysis fixes
1 parent 59da46c commit 2f5a403

File tree

10 files changed

+364
-203
lines changed

10 files changed

+364
-203
lines changed

mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class IntegerValueRange {
3333
static IntegerValueRange getMaxRange(Value value);
3434

3535
/// Create an integer value range lattice value.
36-
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
36+
IntegerValueRange(OptionalIntRanges value = std::nullopt)
3737
: value(std::move(value)) {}
3838

3939
/// Whether the range is uninitialized. This happens when the state hasn't

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,11 @@ class ConstantIntRanges {
105105

106106
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
107107

108+
using OptionalIntRanges = std::optional<ConstantIntRanges>;
108109
/// The type of the `setResultRanges` callback provided to ops implementing
109110
/// InferIntRangeInterface. It should be called once for each integer result
110111
/// value and be passed the ConstantIntRanges corresponding to that value.
111-
using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
112+
using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
112113
} // end namespace mlir
113114

114115
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"

mlir/include/mlir/Interfaces/InferIntRangeInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
4545
APInts in their `argRanges` element.
4646
}],
4747
"void", "inferResultRanges", (ins
48-
"::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
48+
"::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
4949
"::mlir::SetIntRangeFn":$setResultRanges)
5050
>];
5151
}

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ namespace intrange {
2727
using InferRangeFn =
2828
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
2929

30+
using OptionalRangeFn =
31+
std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
32+
3033
static constexpr unsigned indexMinWidth = 32;
3134
static constexpr unsigned indexMaxWidth = 64;
3235

@@ -44,6 +47,8 @@ enum class OverflowFlags : uint32_t {
4447
using InferRangeWithOvfFlagsFn =
4548
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
4649

50+
OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
51+
4752
/// Compute `inferFn` on `ranges`, whose size should be the index storage
4853
/// bitwidth. Then, compute the function on `argRanges` again after truncating
4954
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,25 @@
3636
using namespace mlir;
3737
using namespace mlir::dataflow;
3838

39+
namespace {
40+
41+
OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
42+
if (range.isUninitialized())
43+
return std::nullopt;
44+
return range.getValue();
45+
}
46+
47+
OptionalIntRanges getOptionalRangeFromLattice(const IntegerValueRangeLattice* lattice) {
48+
return getOptionalRange(lattice->getValue());
49+
}
50+
51+
} // end namespace
52+
3953
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
4054
unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
55+
if (width == 0)
56+
return {};
57+
4158
APInt umin = APInt::getMinValue(width);
4259
APInt umax = APInt::getMaxValue(width);
4360
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
@@ -71,23 +88,15 @@ void IntegerRangeAnalysis::visitOperation(
7188
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
7289
ArrayRef<IntegerValueRangeLattice *> results) {
7390
// If the lattice on any operand is unitialized, bail out.
74-
if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
75-
return lattice->getValue().isUninitialized();
76-
})) {
77-
return;
78-
}
79-
8091
auto inferrable = dyn_cast<InferIntRangeInterface>(op);
8192
if (!inferrable)
8293
return setAllToEntryStates(results);
8394

8495
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
85-
SmallVector<ConstantIntRanges> argRanges(
86-
llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
87-
return val->getValue().getValue();
88-
}));
96+
SmallVector<OptionalIntRanges> argRanges(llvm::map_range(
97+
operands, getOptionalRangeFromLattice));
8998

90-
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
99+
auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
91100
auto result = dyn_cast<OpResult>(v);
92101
if (!result)
93102
return;
@@ -97,7 +106,9 @@ void IntegerRangeAnalysis::visitOperation(
97106
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
98107
IntegerValueRange oldRange = lattice->getValue();
99108

100-
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
109+
ChangeResult changed =
110+
attrs ? lattice->join(IntegerValueRange{attrs})
111+
: lattice->join(IntegerValueRange::getMaxRange(v));
101112

102113
// Catch loop results with loop variant bounds and conservatively make
103114
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -127,12 +138,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
127138
return getLatticeElementFor(op, value)->getValue().isUninitialized();
128139
}))
129140
return;
130-
SmallVector<ConstantIntRanges> argRanges(
141+
SmallVector<OptionalIntRanges> argRanges(
131142
llvm::map_range(op->getOperands(), [&](Value value) {
132-
return getLatticeElementFor(op, value)->getValue().getValue();
143+
return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
133144
}));
134145

135-
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
146+
auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
136147
auto arg = dyn_cast<BlockArgument>(v);
137148
if (!arg)
138149
return;
@@ -143,7 +154,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
143154
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
144155
IntegerValueRange oldRange = lattice->getValue();
145156

146-
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
157+
ChangeResult changed =
158+
attrs ? lattice->join(IntegerValueRange{attrs})
159+
: lattice->join(IntegerValueRange::getMaxRange(v));
147160

148161
// Catch loop results with loop variant bounds and conservatively make
149162
// them [-inf, inf] so we don't circle around infinitely often (because

0 commit comments

Comments
 (0)