Skip to content

[mlir] IntegerRangeAnalysis: add support for vector type #112292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS

include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
Comment on lines +20 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These look like unrelated changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added mlir/Interfaces/InferIntRangeInterface.td and sorted rests of the includes.


// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
Expand Down Expand Up @@ -346,6 +347,7 @@ def Vector_MultiDimReductionOp :

def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyType:$source)>,
Expand Down Expand Up @@ -627,6 +629,7 @@ def Vector_DeinterleaveOp :

def Vector_ExtractElementOp :
Vector_Op<"extractelement", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
"::llvm::cast<VectorType>($_self).getElementType()">]>,
Expand Down Expand Up @@ -673,6 +676,7 @@ def Vector_ExtractElementOp :

def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
InferTypeOpAdaptorWithIsCompatible]> {
Expand Down Expand Up @@ -810,6 +814,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [

def Vector_InsertElementOp :
Vector_Op<"insertelement", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"source operand type matches element type of result",
"result", "source",
"::llvm::cast<VectorType>($_self).getElementType()">,
Expand Down Expand Up @@ -858,6 +863,7 @@ def Vector_InsertElementOp :

def Vector_InsertOp :
Vector_Op<"insert", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "result"]>]> {
Expand Down Expand Up @@ -2204,7 +2210,9 @@ def Vector_CompressStoreOp :
}

def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure]>,
Vector_Op<"shape_cast", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
Expand Down Expand Up @@ -2801,6 +2809,7 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,

def Vector_SplatOp : Vector_Op<"splat", [
Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"::llvm::cast<VectorType>($_self).getElementType()">
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
Expand Down Expand Up @@ -53,9 +54,10 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
dialect = parent->getDialect();
else
dialect = value.getParentBlock()->getParentOp()->getDialect();

Type type = getElementTypeOrSelf(value);
solver->propagateIfChanged(
cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
dialect)));
cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
}

LogicalResult IntegerRangeAnalysis::visitOperation(
Expand Down
18 changes: 15 additions & 3 deletions mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,22 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {

void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
if (constAttr) {
const APInt &value = constAttr.getValue();
if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
const APInt &value = scalarCstAttr.getValue();
setResultRange(getResult(), ConstantIntRanges::constant(value));
return;
}
if (auto arrayCstAttr =
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
std::optional<ConstantIntRanges> result;
for (const APInt &val : arrayCstAttr) {
auto range = ConstantIntRanges::constant(val);
result = (result ? result->rangeUnion(range) : range);
}

assert(result && "Zero-sized vectors are not allowed");
setResultRange(getResult(), *result);
return;
}
}

Expand Down
18 changes: 12 additions & 6 deletions mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,27 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
if (!maybeConstValue.has_value())
return failure();

Type type = value.getType();
Location loc = value.getLoc();
Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
Attribute constAttr =
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
Operation *constOp = valueDialect->materializeConstant(
rewriter, constAttr, value.getType(), value.getLoc());

Attribute constAttr;
if (auto shaped = dyn_cast<ShapedType>(type)) {
constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
} else {
constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
}
Operation *constOp =
valueDialect->materializeConstant(rewriter, constAttr, type, loc);
// Fall back to arith.constant if the dialect materializer doesn't know what
// to do with an integer constant.
if (!constOp)
constOp = rewriter.getContext()
->getLoadedDialect<ArithDialect>()
->materializeConstant(rewriter, constAttr, value.getType(),
value.getLoc());
->materializeConstant(rewriter, constAttr, type, loc);
if (!constOp)
return failure();

Expand Down
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,11 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ExtractElementOp
//===----------------------------------------------------------------------===//

void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source) {
result.addOperands({source});
Expand Down Expand Up @@ -1273,6 +1278,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// ExtractOp
//===----------------------------------------------------------------------===//

void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, int64_t position) {
build(builder, result, source, ArrayRef<int64_t>{position});
Expand Down Expand Up @@ -2252,6 +2262,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
Expand Down Expand Up @@ -2713,6 +2728,11 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
// InsertElementOp
//===----------------------------------------------------------------------===//

void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest) {
build(builder, result, source, dest, {});
Expand Down Expand Up @@ -2762,6 +2782,11 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//

void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, int64_t position) {
build(builder, result, source, dest, ArrayRef<int64_t>{position});
Expand Down Expand Up @@ -5277,6 +5302,11 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ShapeCastOp
//===----------------------------------------------------------------------===//

void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
}

/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
Expand Down Expand Up @@ -6423,6 +6453,11 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
}

//===----------------------------------------------------------------------===//
// WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//
Expand Down
106 changes: 106 additions & 0 deletions mlir/test/Dialect/Vector/int-range-interface.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s


// CHECK-LABEL: func @constant_vec
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
func.func @constant_vec() -> vector<8xindex> {
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
%1 = test.reflect_bounds %0 : vector<8xindex>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really know what these test Ops do and I couldn't find any documentation in code. Could add some docs, pls?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test ops are from existing integer range inference tests - they have an implement of the integer range inference interface that sets attributes on reflect_bounds to match the bounds of the input

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I gathered that much from mlir/test/lib/Dialect/Test/TestOps.td, but it doesn’t quite clarify things for me. The lack of documentation for these operations makes it hard to understand the distinction between test.reflect_bounds and test.with_bounds.

Given that @Hardcode84 is already using these ops for testing, it would be fantastic if some of that expertise could be shared through documentation. This would benefit everyone working with these tests!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some description to the ops

func.return %1 : vector<8xindex>
}

// CHECK-LABEL: func @constant_splat
// CHECK: test.reflect_bounds {smax = 3 : si32, smin = 3 : si32, umax = 3 : ui32, umin = 3 : ui32}
func.func @constant_splat() -> vector<8xi32> {
%0 = arith.constant dense<3> : vector<8xi32>
%1 = test.reflect_bounds %0 : vector<8xi32>
func.return %1 : vector<8xi32>
}

// CHECK-LABEL: func @vector_splat
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_splat() -> vector<4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
%1 = vector.splat %0 : vector<4xindex>
%2 = test.reflect_bounds %1 : vector<4xindex>
func.return %2 : vector<4xindex>
}

// CHECK-LABEL: func @vector_broadcast
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_broadcast() -> vector<4x16xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
%1 = vector.broadcast %0 : vector<16xindex> to vector<4x16xindex>
%2 = test.reflect_bounds %1 : vector<4x16xindex>
func.return %2 : vector<4x16xindex>
}

// CHECK-LABEL: func @vector_shape_cast
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_shape_cast() -> vector<4x4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
%1 = vector.shape_cast %0 : vector<16xindex> to vector<4x4xindex>
%2 = test.reflect_bounds %1 : vector<4x4xindex>
func.return %2 : vector<4x4xindex>
}

// CHECK-LABEL: func @vector_extract
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
func.func @vector_extract() -> index {
%0 = test.with_bounds { umin = 5 : index, umax = 6 : index, smin = 5 : index, smax = 6 : index } : vector<4xindex>
%1 = vector.extract %0[0] : index from vector<4xindex>
%2 = test.reflect_bounds %1 : index
func.return %2 : index
}

// CHECK-LABEL: func @vector_extractelement
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
func.func @vector_extractelement() -> index {
%c0 = arith.constant 0 : index
%0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
%1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
%2 = test.reflect_bounds %1 : index
func.return %2 : index
}

// CHECK-LABEL: func @vector_add
// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
func.func @vector_add() -> vector<4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
%1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
%2 = arith.addi %0, %1 : vector<4xindex>
%3 = test.reflect_bounds %2 : vector<4xindex>
func.return %3 : vector<4xindex>
}

// CHECK-LABEL: func @vector_insert
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
func.func @vector_insert() -> vector<4xindex> {
%0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
%1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
%2 = vector.insert %1, %0[0] : index into vector<4xindex>
%3 = test.reflect_bounds %2 : vector<4xindex>
func.return %3 : vector<4xindex>
}

// CHECK-LABEL: func @vector_insertelement
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
func.func @vector_insertelement() -> vector<4xindex> {
%c0 = arith.constant 0 : index
%0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
%1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
%2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex>
%3 = test.reflect_bounds %2 : vector<4xindex>
func.return %3 : vector<4xindex>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, just because I can't remember exactly rangeUnion() does on un-annotated values, could I get a test that goes something like

func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
  %c0 = arith.constant 0 : index 
  %v = vector.load %memref[%c0] : vector<4xi32>
  %e = vector.extract %v[0]
  %bounds = test.reflect_bounds %e : i32
  func.return %bounds : i32
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


// CHECK-LABEL: func @test_loaded_vector_extract
// No bounds
// CHECK: test.reflect_bounds %{{.*}} : i32
func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
%c0 = arith.constant 0 : index
%v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
%e = vector.extract %v[0] : i32 from vector<4xi32>
%bounds = test.reflect_bounds %e : i32
func.return %bounds : i32
}
5 changes: 3 additions & 2 deletions mlir/test/lib/Dialect/Test/TestOpDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,12 +760,13 @@ void TestReflectBoundsOp::inferResultRanges(
Type sIntTy, uIntTy;
// For plain `IntegerType`s, we can derive the appropriate signed and unsigned
// Types for the Attributes.
if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
Type type = getElementTypeOrSelf(getType());
if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
unsigned bitwidth = intTy.getWidth();
sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
} else
sIntTy = uIntTy = getType();
sIntTy = uIntTy = type;

setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
Expand Down
Loading
Loading