Skip to content

Commit eae5ca9

Browse files
authored
[mlir][Vector] Support poison in vector.shuffle mask (#122188)
This PR extends the existing poison support in https://mlir.llvm.org/docs/Dialects/UBOps/ by representing poison mask values in `vector.shuffle`. Similar to LLVM (see https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/Instructions.h#L1884) this requires defining an integer value (`-1`) to represent poison in the `vector.shuffle` mask.
1 parent bb59eb8 commit eae5ca9

File tree

5 files changed

+52
-6
lines changed

5 files changed

+52
-6
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,10 +434,9 @@ def Vector_ShuffleOp
434434
The shuffle operation constructs a permutation (or duplication) of elements
435435
from two input vectors, returning a vector with the same element type as
436436
the input and a length that is the same as the shuffle mask. The two input
437-
vectors must have the same element type, same rank , and trailing dimension
438-
sizes and shuffles their values in the
439-
leading dimension (which may differ in size) according to the given mask.
440-
The legality rules are:
437+
vectors must have the same element type, same rank, and trailing dimension
438+
sizes and shuffles their values in the leading dimension (which may differ
439+
in size) according to the given mask. The legality rules are:
441440
* the two operands must have the same element type as the result
442441
- Either, the two operands and the result must have the same
443442
rank and trailing dimension sizes, viz. given two k-D operands
@@ -448,7 +447,9 @@ def Vector_ShuffleOp
448447
* the mask length equals the leading dimension size of the result
449448
* numbering the input vector indices left to right across the operands, all
450449
mask values must be within range, viz. given two k-D operands v1 and v2
451-
above, all mask values are in the range [0,s_1+t_1)
450+
above, all mask values are in the range [0,s_1+t_1). The value `-1`
451+
represents a poison mask value, which specifies that the selected element
452+
is poison.
452453

453454
Note, scalable vectors are not supported.
454455

@@ -463,10 +464,15 @@ def Vector_ShuffleOp
463464
: vector<2xf32>, vector<2xf32> ; yields vector<4xf32>
464465
%3 = vector.shuffle %a, %b[0, 1]
465466
: vector<f32>, vector<f32> ; yields vector<2xf32>
467+
%4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1]
468+
: vector<4xf32>, vector<4xf32> ; yields vector<6xf32>
466469
```
467470
}];
468471

469472
let extraClassDeclaration = [{
473+
// Integer to represent a poison value in a vector shuffle mask.
474+
static constexpr int64_t kMaskPoisonValue = -1;
475+
470476
VectorType getV1VectorType() {
471477
return ::llvm::cast<VectorType>(getV1().getType());
472478
}
@@ -700,6 +706,8 @@ def Vector_ExtractOp :
700706
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
701707
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
702708
```
709+
710+
TODO: Implement support for poison indices.
703711
}];
704712

705713
let arguments = (ins
@@ -890,6 +898,8 @@ def Vector_InsertOp :
890898
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
891899
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
892900
```
901+
902+
TODO: Implement support for poison indices.
893903
}];
894904

895905
let arguments = (ins
@@ -980,6 +990,8 @@ def Vector_ScalableInsertOp :
980990
```mlir
981991
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
982992
```
993+
994+
TODO: Implement support for poison indices.
983995
}];
984996

985997
let assemblyFormat = [{
@@ -1031,6 +1043,8 @@ def Vector_ScalableExtractOp :
10311043
```mlir
10321044
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
10331045
```
1046+
1047+
TODO: Implement support for poison indices.
10341048
}];
10351049

10361050
let assemblyFormat = [{
@@ -1075,6 +1089,8 @@ def Vector_InsertStridedSliceOp :
10751089
{offsets = [0, 0, 2], strides = [1, 1]}:
10761090
vector<2x4xf32> into vector<16x4x8xf32>
10771091
```
1092+
1093+
TODO: Implement support for poison indices.
10781094
}];
10791095

10801096
let assemblyFormat = [{
@@ -1220,6 +1236,8 @@ def Vector_ExtractStridedSliceOp :
12201236
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
12211237
vector<4x8x16xf32> to vector<2x4x16xf32>
12221238
```
1239+
1240+
TODO: Implement support for poison indices.
12231241
}];
12241242
let builders = [
12251243
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$offsets,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() {
26002600
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
26012601
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
26022602
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2603-
if (maskPos < 0 || maskPos >= indexSize)
2603+
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
26042604
return emitOpError("mask index #") << (idx + 1) << " out of range";
26052605
}
26062606
return success();

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex
11051105

11061106
// -----
11071107

1108+
func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> {
1109+
%1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32>
1110+
return %1 : vector<4xf32>
1111+
}
1112+
// CHECK-LABEL: @shuffle_poison_mask(
1113+
// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>)
1114+
// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32>
1115+
1116+
// -----
1117+
11081118
func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
11091119
%1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
11101120
return %1 : vector<5xf32>

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
613613

614614
// -----
615615

616+
// CHECK-LABEL: func @shuffle
617+
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
618+
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32>
619+
// CHECK: return %[[SHUFFLE]] : vector<4xi32>
620+
func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> {
621+
%shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32>
622+
return %shuffle : vector<4xi32>
623+
}
624+
625+
// -----
626+
616627
// CHECK-LABEL: func @interleave
617628
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
618629
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32
190190
return %1 : vector<3x4xf32>
191191
}
192192

193+
// CHECK-LABEL: @shuffle_poison_mask
194+
func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> {
195+
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
196+
%1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
197+
return %1 : vector<4xf32>
198+
}
199+
193200
// CHECK-LABEL: @extract_element_0d
194201
func.func @extract_element_0d(%a: vector<f32>) -> f32 {
195202
// CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>

0 commit comments

Comments
 (0)