Skip to content

Commit ef2c812

Browse files
author
Spenser Bauman
committed
Add new IntegerRangeAnalysis interface method
This new method allows downstream implementers to easily opt into the old behavior while providing an easy way to transition to the more powerful interface methods.
1 parent b354130 commit ef2c812

File tree

15 files changed

+303
-404
lines changed

15 files changed

+303
-404
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
4949
// Base class for integer binary operations.
5050
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
5151
Arith_BinaryOp<mnemonic, traits #
52-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
52+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
5353
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
5454
Results<(outs SignlessIntegerLike:$result)>;
5555

@@ -107,7 +107,7 @@ class Arith_IToICastOp<string mnemonic, list<Trait> traits = []> :
107107
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike,
108108
SignlessFixedWidthIntegerLike,
109109
traits #
110-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]>;
110+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>;
111111
// Cast from an integer type to a floating point type.
112112
class Arith_IToFCastOp<string mnemonic, list<Trait> traits = []> :
113113
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike, FloatLike, traits>;
@@ -139,7 +139,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
139139

140140
class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
141141
Arith_BinaryOp<mnemonic, traits #
142-
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
142+
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
143143
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
144144
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
145145
DefaultValuedAttr<
@@ -159,7 +159,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
159159
[ConstantLike, Pure,
160160
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
161161
AllTypesMatch<["value", "result"]>,
162-
DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
162+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
163163
let summary = "integer or floating point constant";
164164
let description = [{
165165
The `constant` operation produces an SSA value equal to some integer or
@@ -1327,7 +1327,7 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[
13271327

13281328
def Arith_IndexCastOp
13291329
: Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
1330-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
1330+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
13311331
let summary = "cast between index and integer types";
13321332
let description = [{
13331333
Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1346,7 +1346,7 @@ def Arith_IndexCastOp
13461346

13471347
def Arith_IndexCastUIOp
13481348
: Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
1349-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
1349+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
13501350
let summary = "unsigned cast between index and integer types";
13511351
let description = [{
13521352
Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1400,7 +1400,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
14001400

14011401
def Arith_CmpIOp
14021402
: Arith_CompareOpOfAnyRank<"cmpi",
1403-
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
1403+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
14041404
let summary = "integer comparison operation";
14051405
let description = [{
14061406
The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1555,7 +1555,7 @@ class ScalarConditionOrMatchingShape<list<string> names> :
15551555
def SelectOp : Arith_Op<"select", [Pure,
15561556
AllTypesMatch<["true_value", "false_value", "result"]>,
15571557
ScalarConditionOrMatchingShape<["condition", "result"]>,
1558-
DeclareOpInterfaceMethods<InferIntRangeInterface>,
1558+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
15591559
] # ElementwiseMappable.traits> {
15601560
let summary = "select operation";
15611561
let description = [{

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def GPU_DimensionAttr : EnumAttr<GPU_Dialect, GPU_Dimension, "dim">;
5252
class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
5353
GPU_Op<mnemonic, !listconcat(traits, [
5454
Pure,
55-
DeclareOpInterfaceMethods<InferIntRangeInterface>,
55+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
5656
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>])>,
5757
Arguments<(ins GPU_DimensionAttr:$dimension)>, Results<(outs Index)> {
5858
let assemblyFormat = "$dimension attr-dict";
@@ -144,7 +144,7 @@ def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
144144
}
145145

146146
def GPU_LaneIdOp : GPU_Op<"lane_id", [
147-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
147+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
148148
let description = [{
149149
Returns the lane id within the subgroup (warp/wave).
150150

@@ -158,7 +158,7 @@ def GPU_LaneIdOp : GPU_Op<"lane_id", [
158158
}
159159

160160
def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [
161-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
161+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
162162
Arguments<(ins)>, Results<(outs Index:$result)> {
163163
let description = [{
164164
Returns the subgroup id, i.e., the index of the current subgroup within the
@@ -190,7 +190,7 @@ def GPU_GlobalIdOp : GPU_IndexOp<"global_id"> {
190190

191191

192192
def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
193-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
193+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
194194
Arguments<(ins)>, Results<(outs Index:$result)> {
195195
let description = [{
196196
Returns the number of subgroups within a workgroup.
@@ -206,7 +206,7 @@ def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
206206
}
207207

208208
def GPU_SubgroupSizeOp : GPU_Op<"subgroup_size", [
209-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
209+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
210210
Arguments<(ins)>, Results<(outs Index:$result)> {
211211
let description = [{
212212
Returns the number of threads within a subgroup.
@@ -687,7 +687,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
687687

688688
def GPU_LaunchOp : GPU_Op<"launch", [
689689
AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface,
690-
DeclareOpInterfaceMethods<InferIntRangeInterface>,
690+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
691691
RecursiveMemoryEffects]>,
692692
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
693693
Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ include "mlir/IR/OpBase.td"
2525
/// Base class for Index dialect operations.
2626
class IndexOp<string mnemonic, list<Trait> traits = []>
2727
: Op<IndexDialect, mnemonic,
28-
[DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
28+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>] # traits>;
2929

3030
//===----------------------------------------------------------------------===//
3131
// IndexBinaryOp

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,30 @@ raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
158158
/// The type of the `setResultRanges` callback provided to ops implementing
159159
/// InferIntRangeInterface. It should be called once for each integer result
160160
/// value and be passed the ConstantIntRanges corresponding to that value.
161-
using SetIntRangeFn = function_ref<void(Value, const IntegerValueRange &)>;
161+
using SetIntRangeFn =
162+
llvm::function_ref<void(Value, const ConstantIntRanges &)>;
163+
164+
/// Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
165+
/// This is the `setResultRanges` callback for the IntegerValueRange based
166+
/// interface method.
167+
using SetIntLatticeFn =
168+
llvm::function_ref<void(Value, const IntegerValueRange &)>;
169+
170+
class InferIntRangeInterface;
171+
172+
namespace intrange::detail {
173+
/// Default implementation of `inferResultRanges` which dispatches to the
174+
/// `inferResultRangesFromOptional`.
175+
void defaultInferResultRanges(InferIntRangeInterface interface,
176+
ArrayRef<IntegerValueRange> argRanges,
177+
SetIntLatticeFn setResultRanges);
178+
179+
/// Default implementation of `inferResultRangesFromOptional` which dispatches
180+
/// to the `inferResultRanges`.
181+
void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
182+
ArrayRef<ConstantIntRanges> argRanges,
183+
SetIntRangeFn setResultRanges);
184+
} // end namespace intrange::detail
162185
} // end namespace mlir
163186

164187
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"

mlir/include/mlir/Interfaces/InferIntRangeInterface.td

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
2828
Infer the bounds on the results of this op given the bounds on its arguments.
2929
For each result value or block argument (that isn't a branch argument,
3030
since the dataflow analysis handles those case), the method should call
31-
`setValueRange` with that `Value` as an argument. When `setValueRange`
32-
is not called for some value, it will recieve a default value of the mimimum
33-
and maximum values for its type (the unbounded range).
31+
`setValueRange` with that `Value` as an argument. When implemented,
32+
`setValueRange` should be called on all result values for the operation.
33+
When operations take non-integer inputs, the
34+
`inferResultRangesFromOptional` method should be implemented instead.
3435

3536
When called on an op that also implements the RegionBranchOpInterface
3637
or BranchOpInterface, this method should not attempt to infer the values
@@ -39,14 +40,39 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
3940

4041
This function will only be called when at least one result of the op is a
4142
scalar integer value or the op has a region.
43+
}],
44+
/*retTy=*/"void",
45+
/*methodName=*/"inferResultRanges",
46+
/*args=*/(ins "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
47+
"::mlir::SetIntRangeFn":$setResultRanges),
48+
/*methodBody=*/"",
49+
/*defaultImplementation=*/[{
50+
::mlir::intrange::detail::defaultInferResultRangesFromOptional($_op,
51+
argRanges,
52+
setResultRanges);
53+
}]>,
54+
55+
InterfaceMethod<[{
56+
Infer the bounds on the results of this op given the lattice representation
57+
of the bounds for its arguments. For each result value or block argument
58+
(that isn't a branch argument, since the dataflow analysis handles
59+
those case), the method should call `setValueRange` with that `Value`
60+
as an argument. When implemented, `setValueRange` should be called on
61+
all result values for the operation.
4262

43-
`argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
44-
order. Non-integer arguments will have the an unbounded range of width-0
45-
APInts in their `argRanges` element.
63+
This method allows for more precise implementations when operations
64+
want to reason about inputs which may be undefined during the analysis.
4665
}],
47-
"void", "inferResultRanges", (ins
48-
"::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
49-
"::mlir::SetIntRangeFn":$setResultRanges)
50-
>];
66+
/*retTy=*/"void",
67+
/*methodName=*/"inferResultRangesFromOptional",
68+
/*args=*/(ins "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
69+
"::mlir::SetIntLatticeFn":$setResultRanges),
70+
/*methodBody=*/"",
71+
/*defaultImplementation=*/[{
72+
::mlir::intrange::detail::defaultInferResultRanges($_op,
73+
argRanges,
74+
setResultRanges);
75+
}]>
76+
];
5177
}
5278
#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,6 @@ enum class OverflowFlags : uint32_t {
4848
using InferRangeWithOvfFlagsFn =
4949
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
5050

51-
/// Perform a pointwise extension of a function operating on `ConstantIntRanges`
52-
/// to a function operating on `IntegerValueRange` such that undefined input
53-
/// ranges propagate.
54-
InferIntegerValueRangeFn
55-
inferFromIntegerValueRange(intrange::InferRangeFn inferFn);
56-
5751
/// Compute `inferFn` on `ranges`, whose size should be the index storage
5852
/// bitwidth. Then, compute the function on `argRanges` again after truncating
5953
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
6161
void IntegerRangeAnalysis::visitOperation(
6262
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
6363
ArrayRef<IntegerValueRangeLattice *> results) {
64-
// If the lattice on any operand is unitialized, bail out.
6564
auto inferrable = dyn_cast<InferIntRangeInterface>(op);
6665
if (!inferrable)
6766
return setAllToEntryStates(results);
@@ -99,7 +98,7 @@ void IntegerRangeAnalysis::visitOperation(
9998
propagateIfChanged(lattice, changed);
10099
};
101100

102-
inferrable.inferResultRanges(argRanges, joinCallback);
101+
inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
103102
}
104103

105104
void IntegerRangeAnalysis::visitNonControlFlowArguments(
@@ -140,7 +139,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
140139
propagateIfChanged(lattice, changed);
141140
};
142141

143-
inferrable.inferResultRanges(argRanges, joinCallback);
142+
inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
144143
return;
145144
}
146145

0 commit comments

Comments
 (0)