Skip to content

Commit 6aeea70

Browse files
sabaumaSpenser Bauman
andauthored
[mlir][dataflow] Fix for integer range analysis propagation bug (#93199)
Integer range analysis will not update the range of an operation when any of the inferred input lattices are uninitialized. In the current behavior, all lattice values for non integer types are uninitialized. For operations like arith.cmpf ```mlir %3 = arith.cmpf ugt, %arg0, %arg1 : f32 ``` that will result in the range of the output also being uninitialized, and so on for any consumer of the arith.cmpf result. When control-flow ops are involved, the lack of propagation results in incorrect ranges, as the back edges for loop carried values are not properly joined with the definitions from the body region. For example, an scf.while loop whose body region produces a value that is in a dataflow relationship with some floating-point values through an arith.cmpf operation: ```mlir func.func @test_bad_range(%arg0: f32, %arg1: f32) -> (index, index) { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %3 = arith.cmpf ugt, %arg0, %arg1 : f32 %1:2 = scf.while (%arg2 = %c0, %arg3 = %c0) : (index, index) -> (index, index) { %2 = arith.cmpi ult, %arg2, %c4 : index scf.condition(%2) %arg2, %arg3 : index, index } do { ^bb0(%arg2: index, %arg3: index): %4 = arith.select %3, %arg3, %arg3 : index %5 = arith.addi %arg2, %c1 : index scf.yield %5, %4 : index, index } return %1#0, %1#1 : index, index } ``` The existing behavior results in the control condition %2 being optimized to true, turning the while loop into an infinite loop. The update to %arg2 through the body region is never factored into the range calculation, as the ranges for the body ops all test as uninitialized. This change causes all values initialized with setToEntryState to be set to some initialized range, even if the values are not integers. --------- Co-authored-by: Spenser Bauman <sabauma@fastmail>
1 parent b9cdea6 commit 6aeea70

File tree

13 files changed

+230
-121
lines changed

13 files changed

+230
-121
lines changed

mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,51 +24,6 @@
2424
namespace mlir {
2525
namespace dataflow {
2626

27-
/// This lattice value represents the integer range of an SSA value.
28-
class IntegerValueRange {
29-
public:
30-
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
31-
/// range that is used to mark the value as unable to be analyzed further,
32-
/// where `t` is the type of `value`.
33-
static IntegerValueRange getMaxRange(Value value);
34-
35-
/// Create an integer value range lattice value.
36-
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
37-
: value(std::move(value)) {}
38-
39-
/// Whether the range is uninitialized. This happens when the state hasn't
40-
/// been set during the analysis.
41-
bool isUninitialized() const { return !value.has_value(); }
42-
43-
/// Get the known integer value range.
44-
const ConstantIntRanges &getValue() const {
45-
assert(!isUninitialized());
46-
return *value;
47-
}
48-
49-
/// Compare two ranges.
50-
bool operator==(const IntegerValueRange &rhs) const {
51-
return value == rhs.value;
52-
}
53-
54-
/// Take the union of two ranges.
55-
static IntegerValueRange join(const IntegerValueRange &lhs,
56-
const IntegerValueRange &rhs) {
57-
if (lhs.isUninitialized())
58-
return rhs;
59-
if (rhs.isUninitialized())
60-
return lhs;
61-
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
62-
}
63-
64-
/// Print the integer value range.
65-
void print(raw_ostream &os) const { os << value; }
66-
67-
private:
68-
/// The known integer value range.
69-
std::optional<ConstantIntRanges> value;
70-
};
71-
7227
/// This lattice element represents the integer value range of an SSA value.
7328
/// When this lattice is updated, it automatically updates the constant value
7429
/// of the SSA value (if the range can be narrowed to one).

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
4949
// Base class for integer binary operations.
5050
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
5151
Arith_BinaryOp<mnemonic, traits #
52-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
52+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
5353
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
5454
Results<(outs SignlessIntegerLike:$result)>;
5555

@@ -107,7 +107,7 @@ class Arith_IToICastOp<string mnemonic, list<Trait> traits = []> :
107107
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike,
108108
SignlessFixedWidthIntegerLike,
109109
traits #
110-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]>;
110+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>;
111111
// Cast from an integer type to a floating point type.
112112
class Arith_IToFCastOp<string mnemonic, list<Trait> traits = []> :
113113
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike, FloatLike, traits>;
@@ -139,7 +139,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
139139

140140
class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
141141
Arith_BinaryOp<mnemonic, traits #
142-
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
142+
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
143143
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
144144
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
145145
DefaultValuedAttr<
@@ -159,7 +159,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
159159
[ConstantLike, Pure,
160160
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
161161
AllTypesMatch<["value", "result"]>,
162-
DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
162+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
163163
let summary = "integer or floating point constant";
164164
let description = [{
165165
The `constant` operation produces an SSA value equal to some integer or
@@ -1327,7 +1327,7 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[
13271327

13281328
def Arith_IndexCastOp
13291329
: Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
1330-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
1330+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
13311331
let summary = "cast between index and integer types";
13321332
let description = [{
13331333
Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1346,7 +1346,7 @@ def Arith_IndexCastOp
13461346

13471347
def Arith_IndexCastUIOp
13481348
: Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
1349-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
1349+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
13501350
let summary = "unsigned cast between index and integer types";
13511351
let description = [{
13521352
Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1400,7 +1400,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
14001400

14011401
def Arith_CmpIOp
14021402
: Arith_CompareOpOfAnyRank<"cmpi",
1403-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
1403+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
14041404
let summary = "integer comparison operation";
14051405
let description = [{
14061406
The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1555,7 +1555,7 @@ class ScalarConditionOrMatchingShape<list<string> names> :
15551555
def SelectOp : Arith_Op<"select", [Pure,
15561556
AllTypesMatch<["true_value", "false_value", "result"]>,
15571557
ScalarConditionOrMatchingShape<["condition", "result"]>,
1558-
DeclareOpInterfaceMethods<InferIntRangeInterface>,
1558+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
15591559
] # ElementwiseMappable.traits> {
15601560
let summary = "select operation";
15611561
let description = [{

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def GPU_DimensionAttr : EnumAttr<GPU_Dialect, GPU_Dimension, "dim">;
5252
class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
5353
GPU_Op<mnemonic, !listconcat(traits, [
5454
Pure,
55-
DeclareOpInterfaceMethods<InferIntRangeInterface>,
55+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
5656
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>])>,
5757
Arguments<(ins GPU_DimensionAttr:$dimension)>, Results<(outs Index)> {
5858
let assemblyFormat = "$dimension attr-dict";
@@ -144,7 +144,7 @@ def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
144144
}
145145

146146
def GPU_LaneIdOp : GPU_Op<"lane_id", [
147-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
147+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
148148
let description = [{
149149
Returns the lane id within the subgroup (warp/wave).
150150

@@ -158,7 +158,7 @@ def GPU_LaneIdOp : GPU_Op<"lane_id", [
158158
}
159159

160160
def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [
161-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
161+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
162162
Arguments<(ins)>, Results<(outs Index:$result)> {
163163
let description = [{
164164
Returns the subgroup id, i.e., the index of the current subgroup within the
@@ -190,7 +190,7 @@ def GPU_GlobalIdOp : GPU_IndexOp<"global_id"> {
190190

191191

192192
def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
193-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
193+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
194194
Arguments<(ins)>, Results<(outs Index:$result)> {
195195
let description = [{
196196
Returns the number of subgroups within a workgroup.
@@ -206,7 +206,7 @@ def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
206206
}
207207

208208
def GPU_SubgroupSizeOp : GPU_Op<"subgroup_size", [
209-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
209+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
210210
Arguments<(ins)>, Results<(outs Index:$result)> {
211211
let description = [{
212212
Returns the number of threads within a subgroup.
@@ -687,7 +687,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
687687

688688
def GPU_LaunchOp : GPU_Op<"launch", [
689689
AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface,
690-
DeclareOpInterfaceMethods<InferIntRangeInterface>,
690+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
691691
RecursiveMemoryEffects]>,
692692
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
693693
Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ include "mlir/IR/OpBase.td"
2525
/// Base class for Index dialect operations.
2626
class IndexOp<string mnemonic, list<Trait> traits = []>
2727
: Op<IndexDialect, mnemonic,
28-
[DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
28+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>] # traits>;
2929

3030
//===----------------------------------------------------------------------===//
3131
// IndexBinaryOp

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,83 @@ class ConstantIntRanges {
105105

106106
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
107107

108+
/// This lattice value represents the integer range of an SSA value.
109+
class IntegerValueRange {
110+
public:
111+
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
112+
/// range that is used to mark the value as unable to be analyzed further,
113+
/// where `t` is the type of `value`.
114+
static IntegerValueRange getMaxRange(Value value);
115+
116+
/// Create an integer value range lattice value.
117+
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
118+
119+
/// Create an integer value range lattice value.
120+
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
121+
: value(std::move(value)) {}
122+
123+
/// Whether the range is uninitialized. This happens when the state hasn't
124+
/// been set during the analysis.
125+
bool isUninitialized() const { return !value.has_value(); }
126+
127+
/// Get the known integer value range.
128+
const ConstantIntRanges &getValue() const {
129+
assert(!isUninitialized());
130+
return *value;
131+
}
132+
133+
/// Compare two ranges.
134+
bool operator==(const IntegerValueRange &rhs) const {
135+
return value == rhs.value;
136+
}
137+
138+
/// Compute the least upper bound of two ranges.
139+
static IntegerValueRange join(const IntegerValueRange &lhs,
140+
const IntegerValueRange &rhs) {
141+
if (lhs.isUninitialized())
142+
return rhs;
143+
if (rhs.isUninitialized())
144+
return lhs;
145+
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
146+
}
147+
148+
/// Print the integer value range.
149+
void print(raw_ostream &os) const { os << value; }
150+
151+
private:
152+
/// The known integer value range.
153+
std::optional<ConstantIntRanges> value;
154+
};
155+
156+
raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
157+
108158
/// The type of the `setResultRanges` callback provided to ops implementing
109159
/// InferIntRangeInterface. It should be called once for each integer result
110160
/// value and be passed the ConstantIntRanges corresponding to that value.
111-
using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
161+
using SetIntRangeFn =
162+
llvm::function_ref<void(Value, const ConstantIntRanges &)>;
163+
164+
/// Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
165+
/// This is the `setResultRanges` callback for the IntegerValueRange based
166+
/// interface method.
167+
using SetIntLatticeFn =
168+
llvm::function_ref<void(Value, const IntegerValueRange &)>;
169+
170+
class InferIntRangeInterface;
171+
172+
namespace intrange::detail {
173+
/// Default implementation of `inferResultRanges` which dispatches to the
174+
/// `inferResultRangesFromOptional`.
175+
void defaultInferResultRanges(InferIntRangeInterface interface,
176+
ArrayRef<IntegerValueRange> argRanges,
177+
SetIntLatticeFn setResultRanges);
178+
179+
/// Default implementation of `inferResultRangesFromOptional` which dispatches
180+
/// to the `inferResultRanges`.
181+
void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
182+
ArrayRef<ConstantIntRanges> argRanges,
183+
SetIntRangeFn setResultRanges);
184+
} // end namespace intrange::detail
112185
} // end namespace mlir
113186

114187
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"

mlir/include/mlir/Interfaces/InferIntRangeInterface.td

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
2828
Infer the bounds on the results of this op given the bounds on its arguments.
2929
For each result value or block argument (that isn't a branch argument,
3030
since the dataflow analysis handles those case), the method should call
31-
`setValueRange` with that `Value` as an argument. When `setValueRange`
32-
is not called for some value, it will recieve a default value of the mimimum
33-
and maximum values for its type (the unbounded range).
31+
`setValueRange` with that `Value` as an argument. When implemented,
32+
`setValueRange` should be called on all result values for the operation.
33+
When operations take non-integer inputs, the
34+
`inferResultRangesFromOptional` method should be implemented instead.
3435

3536
When called on an op that also implements the RegionBranchOpInterface
3637
or BranchOpInterface, this method should not attempt to infer the values
@@ -39,14 +40,39 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
3940

4041
This function will only be called when at least one result of the op is a
4142
scalar integer value or the op has a region.
43+
}],
44+
/*retTy=*/"void",
45+
/*methodName=*/"inferResultRanges",
46+
/*args=*/(ins "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
47+
"::mlir::SetIntRangeFn":$setResultRanges),
48+
/*methodBody=*/"",
49+
/*defaultImplementation=*/[{
50+
::mlir::intrange::detail::defaultInferResultRangesFromOptional($_op,
51+
argRanges,
52+
setResultRanges);
53+
}]>,
54+
55+
InterfaceMethod<[{
56+
Infer the bounds on the results of this op given the lattice representation
57+
of the bounds for its arguments. For each result value or block argument
58+
(that isn't a branch argument, since the dataflow analysis handles
59+
those case), the method should call `setValueRange` with that `Value`
60+
as an argument. When implemented, `setValueRange` should be called on
61+
all result values for the operation.
4262

43-
`argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
44-
order. Non-integer arguments will have the an unbounded range of width-0
45-
APInts in their `argRanges` element.
63+
This method allows for more precise implementations when operations
64+
want to reason about inputs which may be undefined during the analysis.
4665
}],
47-
"void", "inferResultRanges", (ins
48-
"::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
49-
"::mlir::SetIntRangeFn":$setResultRanges)
50-
>];
66+
/*retTy=*/"void",
67+
/*methodName=*/"inferResultRangesFromOptional",
68+
/*args=*/(ins "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
69+
"::mlir::SetIntLatticeFn":$setResultRanges),
70+
/*methodBody=*/"",
71+
/*defaultImplementation=*/[{
72+
::mlir::intrange::detail::defaultInferResultRanges($_op,
73+
argRanges,
74+
setResultRanges);
75+
}]>
76+
];
5177
}
5278
#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ namespace intrange {
2525
/// abstracted away here to permit writing the function that handles both
2626
/// 64- and 32-bit index types.
2727
using InferRangeFn =
28-
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
28+
std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
29+
30+
/// Function that performs inferrence on an array of `IntegerValueRange`.
31+
using InferIntegerValueRangeFn =
32+
std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>;
2933

3034
static constexpr unsigned indexMinWidth = 32;
3135
static constexpr unsigned indexMaxWidth = 64;
@@ -52,7 +56,7 @@ using InferRangeWithOvfFlagsFn =
5256
///
5357
/// The `mode` argument specifies if the unsigned, signed, or both results of
5458
/// the inference computation should be used when comparing the results.
55-
ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
59+
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn,
5660
ArrayRef<ConstantIntRanges> argRanges,
5761
CmpMode mode);
5862

0 commit comments

Comments
 (0)