Skip to content

Commit 35df525

Browse files
authored
[mlir][Vector] Add support for poison indices to Extract/IndexOp (#123488)
Following up on #122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`.
1 parent a06c893 commit 35df525

File tree

15 files changed

+190
-38
lines changed

15 files changed

+190
-38
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1454,7 +1454,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14541454
def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
14551455
let summary = "Convert Vector dialect to SPIR-V dialect";
14561456
let constructor = "mlir::createConvertVectorToSPIRVPass()";
1457-
let dependentDialects = ["spirv::SPIRVDialect"];
1457+
let dependentDialects = [
1458+
"spirv::SPIRVDialect",
1459+
"ub::UBDialect"
1460+
];
14581461
}
14591462

14601463
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Vector/IR/Vector.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {
2626

2727
// Base class for Vector dialect ops.
2828
class Vector_Op<string mnemonic, list<Trait> traits = []> :
29-
Op<Vector_Dialect, mnemonic, traits>;
29+
Op<Vector_Dialect, mnemonic, traits> {
30+
31+
// Includes definitions for operations that support the use of poison values
32+
// within positive index ranges.
33+
code extraPoisonClassDeclaration = [{
34+
// Integer to represent a poison index within a static and positive integer
35+
// range.
36+
static constexpr int64_t kPoisonIndex = -1;
37+
}];
38+
}
3039

3140
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
469469
```
470470
}];
471471

472-
let extraClassDeclaration = [{
473-
// Integer to represent a poison value in a vector shuffle mask.
474-
static constexpr int64_t kMaskPoisonValue = -1;
475-
472+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
476473
VectorType getV1VectorType() {
477474
return ::llvm::cast<VectorType>(getV1().getType());
478475
}
@@ -693,9 +690,10 @@ def Vector_ExtractOp :
693690
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
694691
the proper position. Degenerates to an element type if n-k is zero.
695692

696-
Dynamic indices must be greater or equal to zero and less than the size of
697-
the corresponding dimension. The result is undefined if any index is
698-
out-of-bounds.
693+
Static and dynamic indices must be greater or equal to zero and less than
694+
the size of the corresponding dimension. The result is undefined if any
695+
index is out-of-bounds. The value `-1` represents a poison index, which
696+
specifies that the extracted element is poison.
699697

700698
Example:
701699

@@ -705,9 +703,8 @@ def Vector_ExtractOp :
705703
%3 = vector.extract %1[]: vector<f32> from vector<f32>
706704
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
707705
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
706+
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
708707
```
709-
710-
TODO: Implement support for poison indices.
711708
}];
712709

713710
let arguments = (ins
@@ -724,7 +721,7 @@ def Vector_ExtractOp :
724721
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
725722
];
726723

727-
let extraClassDeclaration = [{
724+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
728725
VectorType getSourceVectorType() {
729726
return ::llvm::cast<VectorType>(getVector().getType());
730727
}
@@ -885,9 +882,10 @@ def Vector_InsertOp :
885882
and inserts the n-D source into the (n+k)-D destination at the proper
886883
position. Degenerates to a scalar or a 0-d vector source type when n = 0.
887884

888-
Dynamic indices must be greater or equal to zero and less than the size of
889-
the corresponding dimension. The result is undefined if any index is
890-
out-of-bounds.
885+
Static and dynamic indices must be greater or equal to zero and less than
886+
the size of the corresponding dimension. The result is undefined if any
887+
index is out-of-bounds. The value `-1` represents a poison index, which
888+
specifies that the resulting vector is poison.
891889

892890
Example:
893891

@@ -897,9 +895,8 @@ def Vector_InsertOp :
897895
%8 = vector.insert %6, %7[] : f32 into vector<f32>
898896
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
899897
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
898+
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
900899
```
901-
902-
TODO: Implement support for poison indices.
903900
}];
904901

905902
let arguments = (ins
@@ -917,7 +914,7 @@ def Vector_InsertOp :
917914
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
918915
];
919916

920-
let extraClassDeclaration = [{
917+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
921918
Type getSourceType() { return getSource().getType(); }
922919
VectorType getDestVectorType() {
923920
return ::llvm::cast<VectorType>(getDest().getType());
@@ -990,15 +987,13 @@ def Vector_ScalableInsertOp :
990987
```mlir
991988
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
992989
```
993-
994-
TODO: Implement support for poison indices.
995990
}];
996991

997992
let assemblyFormat = [{
998993
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
999994
}];
1000995

1001-
let extraClassDeclaration = [{
996+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
1002997
VectorType getSourceVectorType() {
1003998
return ::llvm::cast<VectorType>(getSource().getType());
1004999
}
@@ -1043,15 +1038,13 @@ def Vector_ScalableExtractOp :
10431038
```mlir
10441039
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
10451040
```
1046-
1047-
TODO: Implement support for poison indices.
10481041
}];
10491042

10501043
let assemblyFormat = [{
10511044
$source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
10521045
}];
10531046

1054-
let extraClassDeclaration = [{
1047+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
10551048
VectorType getSourceVectorType() {
10561049
return ::llvm::cast<VectorType>(getSource().getType());
10571050
}
@@ -1089,8 +1082,6 @@ def Vector_InsertStridedSliceOp :
10891082
{offsets = [0, 0, 2], strides = [1, 1]}:
10901083
vector<2x4xf32> into vector<16x4x8xf32>
10911084
```
1092-
1093-
TODO: Implement support for poison indices.
10941085
}];
10951086

10961087
let assemblyFormat = [{

mlir/include/mlir/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
2828
details.
2929
}];
3030
let constructor = "mlir::createCanonicalizerPass()";
31+
let dependentDialects = ["ub::UBDialect"];
3132
let options = [
3233
Option<"topDownProcessingEnabled", "top-down", "bool",
3334
/*default=*/"true",

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1919
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2020
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
21-
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
2221
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
2322
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2423
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -27,7 +26,6 @@
2726
#include "mlir/Pass/Pass.h"
2827
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
2928
#include "mlir/Transforms/DialectConversion.h"
30-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3129
#include <memory>
3230

3331
#define DEBUG_TYPE "convert-to-spirv"

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
16-
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1716
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1817
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1918
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1616
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1717
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18+
#include "mlir/Dialect/UB/IR/UBOps.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/DialectConversion.h"
2021

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2121
#include "mlir/Dialect/Tensor/IR/Tensor.h"
22+
#include "mlir/Dialect/UB/IR/UBOps.h"
2223
#include "mlir/Dialect/Utils/IndexingUtils.h"
2324
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2425
#include "mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12741275
return srcElements[posIdx];
12751276
}
12761277

1278+
// Returns `true` if `index` is either within [0, maxIndex) or equal to
1279+
// `poisonValue`.
1280+
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1281+
int64_t maxIndex) {
1282+
return index == poisonValue || (index >= 0 && index < maxIndex);
1283+
}
1284+
12771285
//===----------------------------------------------------------------------===//
12781286
// ExtractOp
12791287
//===----------------------------------------------------------------------===//
@@ -1355,11 +1363,12 @@ LogicalResult vector::ExtractOp::verify() {
13551363
for (auto [idx, pos] : llvm::enumerate(position)) {
13561364
if (auto attr = dyn_cast<Attribute>(pos)) {
13571365
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1358-
if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1366+
if (!isValidPositiveIndexOrPoison(
1367+
constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
13591368
return emitOpError("expected position attribute #")
13601369
<< (idx + 1)
13611370
<< " to be a non-negative integer smaller than the "
1362-
"corresponding vector dimension";
1371+
"corresponding vector dimension or poison (-1)";
13631372
}
13641373
}
13651374
}
@@ -1977,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19771986
return fromElementsOp.getElements()[flatIndex];
19781987
}
19791988

1980-
OpFoldResult ExtractOp::fold(FoldAdaptor) {
1989+
/// Fold an insert or extract operation into an poison value when a poison index
1990+
/// is found at any dimension of the static position.
1991+
static ub::PoisonAttr
1992+
foldPoisonIndexInsertExtractOp(MLIRContext *context,
1993+
ArrayRef<int64_t> staticPos, int64_t poisonVal) {
1994+
if (!llvm::is_contained(staticPos, poisonVal))
1995+
return ub::PoisonAttr();
1996+
1997+
return ub::PoisonAttr::get(context);
1998+
}
1999+
2000+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
19812001
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
19822002
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
19832003
// mismatch).
19842004
if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
19852005
return getVector();
2006+
if (auto res = foldPoisonIndexInsertExtractOp(
2007+
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2008+
return res;
19862009
if (succeeded(foldExtractOpFromExtractChain(*this)))
19872010
return getResult();
19882011
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2249,6 +2272,21 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22492272
resultType.getNumElements()));
22502273
return success();
22512274
}
2275+
2276+
/// Fold an insert or extract operation into an poison value when a poison index
2277+
/// is found at any dimension of the static position.
2278+
template <typename OpTy>
2279+
LogicalResult
2280+
canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
2281+
if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
2282+
op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
2283+
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
2284+
return success();
2285+
}
2286+
2287+
return failure();
2288+
}
2289+
22522290
} // namespace
22532291

22542292
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
22572295
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
22582296
results.add(foldExtractFromShapeCastToShapeCast);
22592297
results.add(foldExtractFromFromElements);
2298+
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
22602299
}
22612300

22622301
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2600,7 +2639,7 @@ LogicalResult ShuffleOp::verify() {
26002639
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
26012640
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
26022641
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2603-
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
2642+
if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
26042643
return emitOpError("mask index #") << (idx + 1) << " out of range";
26052644
}
26062645
return success();
@@ -2882,7 +2921,8 @@ LogicalResult InsertOp::verify() {
28822921
for (auto [idx, pos] : llvm::enumerate(position)) {
28832922
if (auto attr = pos.dyn_cast<Attribute>()) {
28842923
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2885-
if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2924+
if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
2925+
destVectorType.getDimSize(idx))) {
28862926
return emitOpError("expected position attribute #")
28872927
<< (idx + 1)
28882928
<< " to be a non-negative integer smaller than the "
@@ -3020,6 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
30203060
MLIRContext *context) {
30213061
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
30223062
InsertOpConstantFolder>(context);
3063+
results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
30233064
}
30243065

30253066
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
@@ -3028,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30283069
// (type mismatch).
30293070
if (getNumIndices() == 0 && getSourceType() == getType())
30303071
return getSource();
3072+
if (auto res = foldPoisonIndexInsertExtractOp(
3073+
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3074+
return res;
3075+
30313076
return {};
30323077
}
30333078

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
3737
MLIRSideEffectInterfaces
3838
MLIRSupport
3939
MLIRTransformUtils
40+
MLIRUBDialect
4041
)

mlir/lib/Transforms/Canonicalizer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Transforms/Passes.h"
1515

16+
#include "mlir/Dialect/UB/IR/UBOps.h"
1617
#include "mlir/Pass/Pass.h"
1718
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1819

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,16 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
12501250

12511251
// -----
12521252

1253+
func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
1254+
%0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
1255+
return %0 : f32
1256+
}
1257+
// CHECK-LABEL: @extract_poison_idx
1258+
// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
1259+
// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>
1260+
1261+
// -----
1262+
12531263
func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
12541264
%0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
12551265
return %0 : f32

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
175175

176176
// -----
177177

178+
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
179+
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
180+
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
181+
return %0: f32
182+
}
183+
184+
// -----
185+
178186
// CHECK-LABEL: @extract_size1_vector
179187
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
180188
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
@@ -256,6 +264,14 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
256264

257265
// -----
258266

267+
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
268+
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
269+
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
270+
return %1: vector<4xf32>
271+
}
272+
273+
// -----
274+
259275
// CHECK-LABEL: @insert_index_vector
260276
// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
261277
func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {

0 commit comments

Comments
 (0)