Skip to content

Commit 4a9ae1c

Browse files
committed
[mlir][Vector] Add support for poison indices to Extract/IndexOp
Following up on #122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`.
1 parent 5a81a55 commit 4a9ae1c

File tree

9 files changed

+108
-25
lines changed

9 files changed

+108
-25
lines changed

mlir/include/mlir/Dialect/Vector/IR/Vector.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {
2626

2727
// Base class for Vector dialect ops.
2828
class Vector_Op<string mnemonic, list<Trait> traits = []> :
29-
Op<Vector_Dialect, mnemonic, traits>;
29+
Op<Vector_Dialect, mnemonic, traits> {
30+
31+
// Includes definitions for operations that support the use of poison values
32+
// within positive index ranges.
33+
code extraPoisonClassDeclaration = [{
34+
// Integer to represent a poison index within a static and positive integer
35+
// range.
36+
static constexpr int64_t kPoisonIndex = -1;
37+
}];
38+
}
3039

3140
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
469469
```
470470
}];
471471

472-
let extraClassDeclaration = [{
473-
// Integer to represent a poison value in a vector shuffle mask.
474-
static constexpr int64_t kMaskPoisonValue = -1;
475-
472+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
476473
VectorType getV1VectorType() {
477474
return ::llvm::cast<VectorType>(getV1().getType());
478475
}
@@ -706,8 +703,6 @@ def Vector_ExtractOp :
706703
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
707704
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
708705
```
709-
710-
TODO: Implement support for poison indices.
711706
}];
712707

713708
let arguments = (ins
@@ -724,7 +719,7 @@ def Vector_ExtractOp :
724719
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
725720
];
726721

727-
let extraClassDeclaration = [{
722+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
728723
VectorType getSourceVectorType() {
729724
return ::llvm::cast<VectorType>(getVector().getType());
730725
}
@@ -898,8 +893,6 @@ def Vector_InsertOp :
898893
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
899894
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
900895
```
901-
902-
TODO: Implement support for poison indices.
903896
}];
904897

905898
let arguments = (ins
@@ -917,7 +910,7 @@ def Vector_InsertOp :
917910
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
918911
];
919912

920-
let extraClassDeclaration = [{
913+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
921914
Type getSourceType() { return getSource().getType(); }
922915
VectorType getDestVectorType() {
923916
return ::llvm::cast<VectorType>(getDest().getType());
@@ -990,15 +983,13 @@ def Vector_ScalableInsertOp :
990983
```mlir
991984
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
992985
```
993-
994-
TODO: Implement support for poison indices.
995986
}];
996987

997988
let assemblyFormat = [{
998989
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
999990
}];
1000991

1001-
let extraClassDeclaration = [{
992+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
1002993
VectorType getSourceVectorType() {
1003994
return ::llvm::cast<VectorType>(getSource().getType());
1004995
}
@@ -1043,15 +1034,13 @@ def Vector_ScalableExtractOp :
10431034
```mlir
10441035
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
10451036
```
1046-
1047-
TODO: Implement support for poison indices.
10481037
}];
10491038

10501039
let assemblyFormat = [{
10511040
$source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
10521041
}];
10531042

1054-
let extraClassDeclaration = [{
1043+
let extraClassDeclaration = extraPoisonClassDeclaration # [{
10551044
VectorType getSourceVectorType() {
10561045
return ::llvm::cast<VectorType>(getSource().getType());
10571046
}
@@ -1089,8 +1078,6 @@ def Vector_InsertStridedSliceOp :
10891078
{offsets = [0, 0, 2], strides = [1, 1]}:
10901079
vector<2x4xf32> into vector<16x4x8xf32>
10911080
```
1092-
1093-
TODO: Implement support for poison indices.
10941081
}];
10951082

10961083
let assemblyFormat = [{

mlir/include/mlir/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
2828
details.
2929
}];
3030
let constructor = "mlir::createCanonicalizerPass()";
31+
let dependentDialects = ["ub::UBDialect"];
3132
let options = [
3233
Option<"topDownProcessingEnabled", "top-down", "bool",
3334
/*default=*/"true",

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2121
#include "mlir/Dialect/Tensor/IR/Tensor.h"
22+
#include "mlir/Dialect/UB/IR/UBOps.h"
2223
#include "mlir/Dialect/Utils/IndexingUtils.h"
2324
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2425
#include "mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12741275
return srcElements[posIdx];
12751276
}
12761277

1278+
// Returns `true` if `index` is either within [0, maxIndex) or equal to
1279+
// `poisonValue`.
1280+
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1281+
int64_t maxIndex) {
1282+
return index == poisonValue || (index >= 0 && index < maxIndex);
1283+
}
1284+
12771285
//===----------------------------------------------------------------------===//
12781286
// ExtractOp
12791287
//===----------------------------------------------------------------------===//
@@ -1355,7 +1363,8 @@ LogicalResult vector::ExtractOp::verify() {
13551363
for (auto [idx, pos] : llvm::enumerate(position)) {
13561364
if (auto attr = dyn_cast<Attribute>(pos)) {
13571365
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1358-
if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1366+
if (!isValidPositiveIndexOrPoison(
1367+
constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
13591368
return emitOpError("expected position attribute #")
13601369
<< (idx + 1)
13611370
<< " to be a non-negative integer smaller than the "
@@ -2249,6 +2258,23 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22492258
resultType.getNumElements()));
22502259
return success();
22512260
}
2261+
2262+
/// Fold an insert or extract operation into an poison value when a poison index
2263+
/// is found at any dimension of the static position.
2264+
template <typename OpTy>
2265+
LogicalResult foldPoisonIndexInsertExtractOp(OpTy op,
2266+
PatternRewriter &rewriter) {
2267+
auto hasPoisonIndex = [](int64_t index) {
2268+
return index == OpTy::kPoisonIndex;
2269+
};
2270+
2271+
if (llvm::none_of(op.getStaticPosition(), hasPoisonIndex))
2272+
return failure();
2273+
2274+
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getResult().getType());
2275+
return success();
2276+
}
2277+
22522278
} // namespace
22532279

22542280
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2257,6 +2283,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
22572283
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
22582284
results.add(foldExtractFromShapeCastToShapeCast);
22592285
results.add(foldExtractFromFromElements);
2286+
results.add(foldPoisonIndexInsertExtractOp<ExtractOp>);
22602287
}
22612288

22622289
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2600,7 +2627,7 @@ LogicalResult ShuffleOp::verify() {
26002627
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
26012628
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
26022629
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2603-
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
2630+
if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
26042631
return emitOpError("mask index #") << (idx + 1) << " out of range";
26052632
}
26062633
return success();
@@ -2882,7 +2909,8 @@ LogicalResult InsertOp::verify() {
28822909
for (auto [idx, pos] : llvm::enumerate(position)) {
28832910
if (auto attr = pos.dyn_cast<Attribute>()) {
28842911
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2885-
if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2912+
if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
2913+
destVectorType.getDimSize(idx))) {
28862914
return emitOpError("expected position attribute #")
28872915
<< (idx + 1)
28882916
<< " to be a non-negative integer smaller than the "
@@ -3020,6 +3048,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
30203048
MLIRContext *context) {
30213049
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
30223050
InsertOpConstantFolder>(context);
3051+
results.add(foldPoisonIndexInsertExtractOp<InsertOp>);
30233052
}
30243053

30253054
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
3737
MLIRSideEffectInterfaces
3838
MLIRSupport
3939
MLIRTransformUtils
40+
MLIRUBDialect
4041
)

mlir/lib/Transforms/Canonicalizer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Transforms/Passes.h"
1515

16+
#include "mlir/Dialect/UB/IR/UBOps.h"
1617
#include "mlir/Pass/Pass.h"
1718
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1819

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,26 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
132132

133133
// -----
134134

135+
// CHECK-LABEL: @extract_scalar_poison_idx
136+
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
137+
// CHECK-NOT: vector.extract
138+
// CHECK-NEXT: ub.poison : f32
139+
%0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
140+
return %0 : f32
141+
}
142+
143+
// -----
144+
145+
// CHECK-LABEL: @extract_vector_poison_idx
146+
func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
147+
// CHECK-NOT: vector.extract
148+
// CHECK-NEXT: ub.poison : vector<5xf32>
149+
%0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
150+
return %0 : vector<5xf32>
151+
}
152+
153+
// -----
154+
135155
// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
136156
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
137157
func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
@@ -2778,7 +2798,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
27782798
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
27792799
}
27802800

2781-
27822801
// -----
27832802

27842803
// CHECK-LABEL: func @vector_insert_const_regression(
@@ -2792,6 +2811,28 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
27922811

27932812
// -----
27942813

2814+
// CHECK-LABEL: @insert_scalar_poison_idx
2815+
func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
2816+
-> vector<4x5xf32> {
2817+
// CHECK-NOT: vector.insert
2818+
// CHECK-NEXT: ub.poison : vector<4x5xf32>
2819+
%0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
2820+
return %0 : vector<4x5xf32>
2821+
}
2822+
2823+
// -----
2824+
2825+
// CHECK-LABEL: @insert_vector_poison_idx
2826+
func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
2827+
-> vector<4x5xf32> {
2828+
// CHECK-NOT: vector.insert
2829+
// CHECK-NEXT: ub.poison : vector<4x5xf32>
2830+
%0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
2831+
return %0 : vector<4x5xf32>
2832+
}
2833+
2834+
// -----
2835+
27952836
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
27962837
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
27972838
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ func.func @extract_0d(%arg0: vector<f32>) {
187187

188188
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
189189
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
190-
%1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
190+
%1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
191191
}
192192

193193
// -----
@@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
247247

248248
func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
249249
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
250-
%1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32>
250+
%1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32>
251251
}
252252

253253
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
247247
return %0 : f32
248248
}
249249

250+
// CHECK-LABEL: @extract_poison_idx
251+
func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
252+
// CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32>
253+
%0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
254+
return %0 : f32
255+
}
256+
250257
// CHECK-LABEL: @insert_element_0d
251258
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
252259
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
@@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f
299306
return %1, %2 : vector<f32>, vector<2x3xf32>
300307
}
301308

309+
// CHECK-LABEL: @insert_poison_idx
310+
func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) {
311+
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32>
312+
vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
313+
return
314+
}
315+
302316
// CHECK-LABEL: @outerproduct
303317
func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
304318
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>

0 commit comments

Comments
 (0)