Skip to content

Commit 7deabd4

Browse files
committed
[mlir][Vector] Support poison in vector.shuffle mask
This PR extends the existing poison support in https://mlir.llvm.org/docs/Dialects/UBOps/ by representing poison mask values in `vector.shuffle`. Similar to LLVM (see https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/Instructions.h#L1884) this requires defining an integer value (`-1`) representing poison in the `vector.shuffle` mask. The current implementation parses and prints `-1` for the poison value. I implemented a custom parser/printer to use the `poison` keyword instead but I think it's an overkill to have to introduce a hand-written parsers/printers for every operation supporting poison. I also explored adding new flavors of `DenseIXArrayAttr` that could take an argument to represent the poison value, but I also desisted as the resulting code was too complex. Happy to get feedback about this and improve the assembly format as a follow-up.
1 parent d6315af commit 7deabd4

File tree

5 files changed

+37
-3
lines changed

5 files changed

+37
-3
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ def Vector_ShuffleOp
434434
The shuffle operation constructs a permutation (or duplication) of elements
435435
from two input vectors, returning a vector with the same element type as
436436
the input and a length that is the same as the shuffle mask. The two input
437-
vectors must have the same element type, same rank , and trailing dimension
437+
vectors must have the same element type, same rank, and trailing dimension
438438
sizes and shuffles their values in the
439439
leading dimension (which may differ in size) according to the given mask.
440440
The legality rules are:
@@ -448,7 +448,8 @@ def Vector_ShuffleOp
448448
* the mask length equals the leading dimension size of the result
449449
* numbering the input vector indices left to right across the operands, all
450450
mask values must be within range, viz. given two k-D operands v1 and v2
451-
above, all mask values are in the range [0,s_1+t_1)
451+
above, all mask values are in the range [0,s_1+t_1). -1 is used to
452+
represent a poison mask value.
452453

453454
Note, scalable vectors are not supported.
454455

@@ -463,10 +464,15 @@ def Vector_ShuffleOp
463464
: vector<2xf32>, vector<2xf32> ; yields vector<4xf32>
464465
%3 = vector.shuffle %a, %b[0, 1]
465466
: vector<f32>, vector<f32> ; yields vector<2xf32>
467+
%4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1]
468+
: vector<4xf32>, vector<4xf32> ; yields vector<6xf32>
466469
```
467470
}];
468471

469472
let extraClassDeclaration = [{
473+
// Integer to represent a poison value in a vector shuffle mask.
474+
static constexpr int64_t kMaskPoisonValue = -1;
475+
470476
VectorType getV1VectorType() {
471477
return ::llvm::cast<VectorType>(getV1().getType());
472478
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() {
26002600
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
26012601
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
26022602
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2603-
if (maskPos < 0 || maskPos >= indexSize)
2603+
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
26042604
return emitOpError("mask index #") << (idx + 1) << " out of range";
26052605
}
26062606
return success();

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex
11051105

11061106
// -----
11071107

1108+
func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> {
1109+
%1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32>
1110+
return %1 : vector<4xf32>
1111+
}
1112+
// CHECK-LABEL: @shuffle_poison_mask(
1113+
// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>)
1114+
// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32>
1115+
1116+
// -----
1117+
11081118
func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
11091119
%1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
11101120
return %1 : vector<5xf32>

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
613613

614614
// -----
615615

616+
// CHECK-LABEL: func @shuffle
617+
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
618+
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32>
619+
// CHECK: return %[[SHUFFLE]] : vector<4xi32>
620+
func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> {
621+
%shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32>
622+
return %shuffle : vector<4xi32>
623+
}
624+
625+
// -----
626+
616627
// CHECK-LABEL: func @interleave
617628
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
618629
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32
190190
return %1 : vector<3x4xf32>
191191
}
192192

193+
// CHECK-LABEL: @shuffle_poison_mask
194+
func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> {
195+
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
196+
%1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
197+
return %1 : vector<4xf32>
198+
}
199+
193200
// CHECK-LABEL: @extract_element_0d
194201
func.func @extract_element_0d(%a: vector<f32>) -> f32 {
195202
// CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>

0 commit comments

Comments
 (0)