Skip to content

[mlir][dataflow] Fix for integer range analysis propagation bug #93199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 0 additions & 45 deletions mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,51 +24,6 @@
namespace mlir {
namespace dataflow {

/// This lattice value represents the integer range of an SSA value.
class IntegerValueRange {
public:
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
/// range that is used to mark the value as unable to be analyzed further,
/// where `t` is the type of `value`.
static IntegerValueRange getMaxRange(Value value);

/// Create an integer value range lattice value.
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
: value(std::move(value)) {}

/// Whether the range is uninitialized. This happens when the state hasn't
/// been set during the analysis.
bool isUninitialized() const { return !value.has_value(); }

/// Get the known integer value range.
const ConstantIntRanges &getValue() const {
assert(!isUninitialized());
return *value;
}

/// Compare two ranges.
bool operator==(const IntegerValueRange &rhs) const {
return value == rhs.value;
}

/// Take the union of two ranges.
static IntegerValueRange join(const IntegerValueRange &lhs,
const IntegerValueRange &rhs) {
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
}

/// Print the integer value range.
void print(raw_ostream &os) const { os << value; }

private:
/// The known integer value range.
std::optional<ConstantIntRanges> value;
};

/// This lattice element represents the integer value range of an SSA value.
/// When this lattice is updated, it automatically updates the constant value
/// of the SSA value (if the range can be narrowed to one).
Expand Down
16 changes: 8 additions & 8 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
// Base class for integer binary operations.
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
Results<(outs SignlessIntegerLike:$result)>;

Expand Down Expand Up @@ -107,7 +107,7 @@ class Arith_IToICastOp<string mnemonic, list<Trait> traits = []> :
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike,
SignlessFixedWidthIntegerLike,
traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface>]>;
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>;
// Cast from an integer type to a floating point type.
class Arith_IToFCastOp<string mnemonic, list<Trait> traits = []> :
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike, FloatLike, traits>;
Expand Down Expand Up @@ -139,7 +139,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :

class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
DefaultValuedAttr<
Expand All @@ -159,7 +159,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
[ConstantLike, Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllTypesMatch<["value", "result"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "integer or floating point constant";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
Expand Down Expand Up @@ -1327,7 +1327,7 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[

def Arith_IndexCastOp
: Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
Expand All @@ -1346,7 +1346,7 @@ def Arith_IndexCastOp

def Arith_IndexCastUIOp
: Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "unsigned cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
Expand Down Expand Up @@ -1400,7 +1400,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,

def Arith_CmpIOp
: Arith_CompareOpOfAnyRank<"cmpi",
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
Expand Down Expand Up @@ -1555,7 +1555,7 @@ class ScalarConditionOrMatchingShape<list<string> names> :
def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
ScalarConditionOrMatchingShape<["condition", "result"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def GPU_DimensionAttr : EnumAttr<GPU_Dialect, GPU_Dimension, "dim">;
class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
GPU_Op<mnemonic, !listconcat(traits, [
Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>])>,
Arguments<(ins GPU_DimensionAttr:$dimension)>, Results<(outs Index)> {
let assemblyFormat = "$dimension attr-dict";
Expand Down Expand Up @@ -144,7 +144,7 @@ def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
}

def GPU_LaneIdOp : GPU_Op<"lane_id", [
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let description = [{
Returns the lane id within the subgroup (warp/wave).

Expand All @@ -158,7 +158,7 @@ def GPU_LaneIdOp : GPU_Op<"lane_id", [
}

def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins)>, Results<(outs Index:$result)> {
let description = [{
Returns the subgroup id, i.e., the index of the current subgroup within the
Expand Down Expand Up @@ -190,7 +190,7 @@ def GPU_GlobalIdOp : GPU_IndexOp<"global_id"> {


def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins)>, Results<(outs Index:$result)> {
let description = [{
Returns the number of subgroups within a workgroup.
Expand All @@ -206,7 +206,7 @@ def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
}

def GPU_SubgroupSizeOp : GPU_Op<"subgroup_size", [
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins)>, Results<(outs Index:$result)> {
let description = [{
Returns the number of threads within a subgroup.
Expand Down Expand Up @@ -687,7 +687,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [

def GPU_LaunchOp : GPU_Op<"launch", [
AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface,
DeclareOpInterfaceMethods<InferIntRangeInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
RecursiveMemoryEffects]>,
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Index/IR/IndexOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ include "mlir/IR/OpBase.td"
/// Base class for Index dialect operations.
class IndexOp<string mnemonic, list<Trait> traits = []>
: Op<IndexDialect, mnemonic,
[DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>] # traits>;

//===----------------------------------------------------------------------===//
// IndexBinaryOp
Expand Down
75 changes: 74 additions & 1 deletion mlir/include/mlir/Interfaces/InferIntRangeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,83 @@ class ConstantIntRanges {

raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);

/// This lattice value represents the integer range of an SSA value.
class IntegerValueRange {
public:
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
/// range that is used to mark the value as unable to be analyzed further,
/// where `t` is the type of `value`.
static IntegerValueRange getMaxRange(Value value);

/// Create an integer value range lattice value.
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}

/// Create an integer value range lattice value.
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
: value(std::move(value)) {}

/// Whether the range is uninitialized. This happens when the state hasn't
/// been set during the analysis.
bool isUninitialized() const { return !value.has_value(); }

/// Get the known integer value range.
const ConstantIntRanges &getValue() const {
assert(!isUninitialized());
return *value;
}

/// Compare two ranges.
bool operator==(const IntegerValueRange &rhs) const {
return value == rhs.value;
}

/// Compute the least upper bound of two ranges.
static IntegerValueRange join(const IntegerValueRange &lhs,
const IntegerValueRange &rhs) {
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
}

/// Print the integer value range.
void print(raw_ostream &os) const { os << value; }

private:
/// The known integer value range.
std::optional<ConstantIntRanges> value;
};

raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);

/// The type of the `setResultRanges` callback provided to ops implementing
/// InferIntRangeInterface. It should be called once for each integer result
/// value and be passed the ConstantIntRanges corresponding to that value.
using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
using SetIntRangeFn =
llvm::function_ref<void(Value, const ConstantIntRanges &)>;

/// Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
/// This is the `setResultRanges` callback for the IntegerValueRange based
/// interface method.
using SetIntLatticeFn =
llvm::function_ref<void(Value, const IntegerValueRange &)>;

class InferIntRangeInterface;

namespace intrange::detail {
/// Default implementation of `inferResultRanges` which dispatches to the
/// `inferResultRangesFromOptional`.
void defaultInferResultRanges(InferIntRangeInterface interface,
ArrayRef<IntegerValueRange> argRanges,
SetIntLatticeFn setResultRanges);

/// Default implementation of `inferResultRangesFromOptional` which dispatches
/// to the `inferResultRanges`.
void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges);
} // end namespace intrange::detail
} // end namespace mlir

#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
Expand Down
46 changes: 36 additions & 10 deletions mlir/include/mlir/Interfaces/InferIntRangeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
Infer the bounds on the results of this op given the bounds on its arguments.
For each result value or block argument (that isn't a branch argument,
since the dataflow analysis handles those case), the method should call
`setValueRange` with that `Value` as an argument. When `setValueRange`
is not called for some value, it will recieve a default value of the mimimum
and maximum values for its type (the unbounded range).
`setValueRange` with that `Value` as an argument. When implemented,
`setValueRange` should be called on all result values for the operation.
When operations take non-integer inputs, the
`inferResultRangesFromOptional` method should be implemented instead.

When called on an op that also implements the RegionBranchOpInterface
or BranchOpInterface, this method should not attempt to infer the values
Expand All @@ -39,14 +40,39 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {

This function will only be called when at least one result of the op is a
scalar integer value or the op has a region.
}],
/*retTy=*/"void",
/*methodName=*/"inferResultRanges",
/*args=*/(ins "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
"::mlir::SetIntRangeFn":$setResultRanges),
/*methodBody=*/"",
/*defaultImplementation=*/[{
::mlir::intrange::detail::defaultInferResultRangesFromOptional($_op,
argRanges,
setResultRanges);
}]>,

InterfaceMethod<[{
Infer the bounds on the results of this op given the lattice representation
of the bounds for its arguments. For each result value or block argument
(that isn't a branch argument, since the dataflow analysis handles
those case), the method should call `setValueRange` with that `Value`
as an argument. When implemented, `setValueRange` should be called on
all result values for the operation.

`argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
order. Non-integer arguments will have the an unbounded range of width-0
APInts in their `argRanges` element.
This method allows for more precise implementations when operations
want to reason about inputs which may be undefined during the analysis.
}],
"void", "inferResultRanges", (ins
"::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
"::mlir::SetIntRangeFn":$setResultRanges)
>];
/*retTy=*/"void",
/*methodName=*/"inferResultRangesFromOptional",
/*args=*/(ins "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
"::mlir::SetIntLatticeFn":$setResultRanges),
/*methodBody=*/"",
/*defaultImplementation=*/[{
::mlir::intrange::detail::defaultInferResultRanges($_op,
argRanges,
setResultRanges);
}]>
];
}
#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE
8 changes: 6 additions & 2 deletions mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ namespace intrange {
/// abstracted away here to permit writing the function that handles both
/// 64- and 32-bit index types.
using InferRangeFn =
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;

/// Function that performs inferrence on an array of `IntegerValueRange`.
using InferIntegerValueRangeFn =
std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>;

static constexpr unsigned indexMinWidth = 32;
static constexpr unsigned indexMaxWidth = 64;
Expand All @@ -52,7 +56,7 @@ using InferRangeWithOvfFlagsFn =
///
/// The `mode` argument specifies if the unsigned, signed, or both results of
/// the inference computation should be used when comparing the results.
ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn,
ArrayRef<ConstantIntRanges> argRanges,
CmpMode mode);

Expand Down
Loading
Loading