[mlir][Vector] Add support for poison indices to Extract/IndexOp #123488

Merged: 3 commits, Jan 28, 2025
5 changes: 4 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
@@ -1454,7 +1454,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
let summary = "Convert Vector dialect to SPIR-V dialect";
let constructor = "mlir::createConvertVectorToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
let dependentDialects = [
"spirv::SPIRVDialect",
"ub::UBDialect"
];
}

//===----------------------------------------------------------------------===//
11 changes: 10 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {

// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits>;
Op<Vector_Dialect, mnemonic, traits> {

// Includes definitions for operations that support the use of poison values
// within positive index ranges.
code extraPoisonClassDeclaration = [{
// Integer to represent a poison index within a static and positive integer
// range.
static constexpr int64_t kPoisonIndex = -1;
}];
Contributor Author

Lazy me started with this approach to refactor and declare the poison index value for different operations (shuffle, insert, extract, ...), thinking that I would eventually turn this into an interface. On second thought, an interface with a getPoisonIndexValue method that just returns -1 feels like overkill. WDYT?

I was also wondering if the OpFoldResult implementation would have a place somewhere to accommodate this declaration. cc: @matthias-springer

Contributor

I don't have an opinion just now, but will have in due course 😂 (when I start interacting with this more).

To me this is an implementation detail that can always be updated.

Member

No opinion here, but -1 is a number that is very easy to write by accident.

Contributor Author

-1 is the value used by LLVM so I'm just sticking to that to prevent unnecessary conversion bugs. We use numeric_limits::min for dynamic shapes so no conflict there. I can't think of an alternative that would make a difference. Any negative number would lead to a verification error or UB so...
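
To make the sentinel argument concrete, here is a minimal standalone sketch in plain C++ with no MLIR headers. Only `kPoisonIndex = -1` comes from the patch; `kDynamicMarker`, `IndexKind`, and `classifyIndex` are hypothetical names introduced for illustration, with `kDynamicMarker` standing in for the `numeric_limits::min` marker mentioned above.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// From the patch: sentinel that marks a poison index.
constexpr int64_t kPoisonIndex = -1;
// Hypothetical stand-in for the dynamic marker mentioned in the comment above.
constexpr int64_t kDynamicMarker = std::numeric_limits<int64_t>::min();

// The two sentinels are distinct, so neither can be mistaken for the other.
static_assert(kPoisonIndex != kDynamicMarker, "sentinels must not collide");

// Hypothetical classification of one static position entry.
enum class IndexKind { InBounds, Poison, Dynamic, Invalid };

// Classifies `index` against a dimension of size `dimSize`.
// Any negative value other than the two sentinels is rejected.
inline IndexKind classifyIndex(int64_t index, int64_t dimSize) {
  assert(dimSize >= 0 && "dimension size must be non-negative");
  if (index == kPoisonIndex)
    return IndexKind::Poison;
  if (index == kDynamicMarker)
    return IndexKind::Dynamic;
  if (index >= 0 && index < dimSize)
    return IndexKind::InBounds;
  return IndexKind::Invalid;
}
```

Under these assumptions, `classifyIndex(-1, 4)` is treated as poison while `classifyIndex(-2, 4)` is still rejected, which matches the point that every other negative number would fail verification.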

}

#endif // MLIR_DIALECT_VECTOR_IR_VECTOR
39 changes: 15 additions & 24 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
```
}];

let extraClassDeclaration = [{
// Integer to represent a poison value in a vector shuffle mask.
static constexpr int64_t kMaskPoisonValue = -1;

let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getV1VectorType() {
return ::llvm::cast<VectorType>(getV1().getType());
}
@@ -693,9 +690,10 @@ def Vector_ExtractOp :
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
the proper position. Degenerates to an element type if n-k is zero.

Dynamic indices must be greater or equal to zero and less than the size of
the corresponding dimension. The result is undefined if any index is
out-of-bounds.
Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
index is out-of-bounds. The value `-1` represents a poison index, which
specifies that the extracted element is poison.

Example:

@@ -705,9 +703,8 @@ def Vector_ExtractOp :
%3 = vector.extract %1[]: vector<f32> from vector<f32>
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
```

TODO: Implement support for poison indices.
}];

let arguments = (ins
@@ -724,7 +721,7 @@ def Vector_ExtractOp :
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
];

let extraClassDeclaration = [{
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
@@ -885,9 +882,10 @@ def Vector_InsertOp :
and inserts the n-D source into the (n+k)-D destination at the proper
position. Degenerates to a scalar or a 0-d vector source type when n = 0.

Dynamic indices must be greater or equal to zero and less than the size of
the corresponding dimension. The result is undefined if any index is
out-of-bounds.
Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
index is out-of-bounds. The value `-1` represents a poison index, which
specifies that the resulting vector is poison.

Example:

@@ -897,9 +895,8 @@ def Vector_InsertOp :
%8 = vector.insert %6, %7[] : f32 into vector<f32>
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
```

TODO: Implement support for poison indices.
}];

let arguments = (ins
@@ -917,7 +914,7 @@ def Vector_InsertOp :
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
];

let extraClassDeclaration = [{
let extraClassDeclaration = extraPoisonClassDeclaration # [{
Type getSourceType() { return getSource().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
@@ -990,15 +987,13 @@ def Vector_ScalableInsertOp :
```mlir
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
```

TODO: Implement support for poison indices.
}];

let assemblyFormat = [{
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
}];

let extraClassDeclaration = [{
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
}
@@ -1043,15 +1038,13 @@ def Vector_ScalableExtractOp :
```mlir
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
```

TODO: Implement support for poison indices.
}];

let assemblyFormat = [{
$source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
}];

let extraClassDeclaration = [{
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
}
@@ -1089,8 +1082,6 @@ def Vector_InsertStridedSliceOp :
{offsets = [0, 0, 2], strides = [1, 1]}:
vector<2x4xf32> into vector<16x4x8xf32>
```

TODO: Implement support for poison indices.
}];

let assemblyFormat = [{
1 change: 1 addition & 0 deletions mlir/include/mlir/Transforms/Passes.td
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
details.
}];
let constructor = "mlir::createCanonicalizerPass()";
let dependentDialects = ["ub::UBDialect"];
let options = [
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
2 changes: 0 additions & 2 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -18,7 +18,6 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -27,7 +26,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>

#define DEBUG_TYPE "convert-to-spirv"
1 change: 0 additions & 1 deletion mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -13,7 +13,6 @@
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
1 change: 1 addition & 0 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

55 changes: 50 additions & 5 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
return srcElements[posIdx];
}

// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
int64_t maxIndex) {
return index == poisonValue || (index >= 0 && index < maxIndex);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
@@ -1355,11 +1363,12 @@ LogicalResult vector::ExtractOp::verify() {
for (auto [idx, pos] : llvm::enumerate(position)) {
if (auto attr = dyn_cast<Attribute>(pos)) {
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
if (!isValidPositiveIndexOrPoison(
constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "
"corresponding vector dimension";
"corresponding vector dimension or poison (-1)";
}
}
}
@@ -1977,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}

OpFoldResult ExtractOp::fold(FoldAdaptor) {
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
static ub::PoisonAttr
foldPoisonIndexInsertExtractOp(MLIRContext *context,
ArrayRef<int64_t> staticPos, int64_t poisonVal) {
if (!llvm::is_contained(staticPos, poisonVal))
return ub::PoisonAttr();

return ub::PoisonAttr::get(context);
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
// mismatch).
if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
return getVector();
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2249,6 +2272,21 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
resultType.getNumElements()));
return success();
}

/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
template <typename OpTy>
LogicalResult
canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
return success();
}

return failure();
}

} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
Member

I posted another comment about this: we probably don't need to duplicate something here that's already in the folder.

No strong opinion, though. If you think we want to keep it here, we can, and then we can discuss in another PR whether we want folders like this in the canonicalizer or not (other folders also seem to be here, which is weird to me).

Contributor Author

The current state, AFAICT, is that the canonicalizer uses applyPatternsGreedily rather than applyPatternsAndFoldGreedily, so foldings are not applied as part of canonicalization. I guess that makes sense, as we may not want every folding to be part of the canonical form (e.g., foldings that might remove structural information). In this particular case, we don't want poison indices to be part of the canonical form, so I think it makes sense to have it in both places.
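
As a rough illustration of the check that both the folder and the canonicalization pattern share, here is a standalone sketch in plain C++ with no MLIR types. It models the decision made by `foldPoisonIndexInsertExtractOp` in the patch: a single poison entry anywhere in the static position makes the whole result poison. `FoldResult` and `foldPoisonIndex` are hypothetical names, and the enum is only a stand-in for returning a null or non-null `ub::PoisonAttr`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// From the patch: sentinel that marks a poison index.
constexpr int64_t kPoisonIndex = -1;

// Hypothetical stand-in for "no attribute" vs. "poison attribute".
enum class FoldResult { NoFold, Poison };

// Returns Poison if any entry of the static position equals the poison
// sentinel, and NoFold otherwise. Entries that are not the sentinel
// (including any marker used for dynamic positions) never trigger the fold.
inline FoldResult foldPoisonIndex(const std::vector<int64_t> &staticPos,
                                  int64_t poisonVal = kPoisonIndex) {
  bool hasPoison = std::find(staticPos.begin(), staticPos.end(), poisonVal) !=
                   staticPos.end();
  return hasPoison ? FoldResult::Poison : FoldResult::NoFold;
}
```

Because the decision depends only on the static position, the same helper can back both entry points, which is how the patch wires it: `fold` returns the attribute directly and the pattern materializes a `ub::PoisonOp` from it.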

}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2600,7 +2639,7 @@ LogicalResult ShuffleOp::verify() {
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
return emitOpError("mask index #") << (idx + 1) << " out of range";
}
return success();
@@ -2882,7 +2921,8 @@ LogicalResult InsertOp::verify() {
for (auto [idx, pos] : llvm::enumerate(position)) {
if (auto attr = pos.dyn_cast<Attribute>()) {
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
destVectorType.getDimSize(idx))) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "
@@ -3020,6 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
InsertOpConstantFolder>(context);
results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
}

OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
@@ -3028,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;

return {};
}

1 change: 1 addition & 0 deletions mlir/lib/Transforms/CMakeLists.txt
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
MLIRUBDialect
)
1 change: 1 addition & 0 deletions mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,6 +13,7 @@

#include "mlir/Transforms/Passes.h"

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

10 changes: 10 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1250,6 +1250,16 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {

// -----

func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
%0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
return %0 : f32
}
// CHECK-LABEL: @extract_poison_idx
// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>

// -----

func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
%0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
return %0 : f32
16 changes: 16 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -175,6 +175,14 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {

// -----

func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
return %0: f32
}

// -----

// CHECK-LABEL: @extract_size1_vector
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
@@ -256,6 +264,14 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {

// -----

func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
return %1: vector<4xf32>
}

// -----

// CHECK-LABEL: @insert_index_vector
// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {