Skip to content

Commit 1fe6568

Browse files
aartbik authored and tensorflower-gardener committed
[VectorOps] Add a ShuffleOp to the VectorOps dialect
For example, `%0 = vector.shuffle %x, %y [3 : i32, 2 : i32, 1 : i32, 0 : i32] : vector<2xf32>, vector<2xf32>` yields a vector<4xf32> result with a permutation of the elements of %x and %y. PiperOrigin-RevId: 284657191
1 parent 0e963b9 commit 1fe6568

File tree

5 files changed

+211
-19
lines changed

5 files changed

+211
-19
lines changed

mlir/include/mlir/Dialect/VectorOps/VectorOps.td

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,59 @@ def Vector_BroadcastOp :
214214
}];
215215
}
216216

217+
// Shuffle operation of the VectorOps dialect. It takes two vector operands
// (`v1`, `v2`) and an i32 array attribute (`mask`) and produces one vector
// result. The two predicate traits require each operand to have the same
// element type as the result. The op declares a custom builder and typed
// accessors for the operand and result vector types.
def Vector_ShuffleOp :
  Vector_Op<"shuffle", [NoSideEffect,
     PredOpTrait<"first operand v1 and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     PredOpTrait<"second operand v2 and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 1>>]>,
     Arguments<(ins AnyVector:$v1, AnyVector:$v2, I32ArrayAttr:$mask)>,
     Results<(outs AnyVector:$vector)> {
  let summary = "shuffle operation";
  let description = [{
    The shuffle operation constructs a permutation (or duplication) of elements
    from two input vectors, returning a vector with the same element type as
    the input and a length that is the same as the shuffle mask. The two input
    vectors must have the same element type, rank, and trailing dimension sizes.
    The operation shuffles their values in the leading dimension (which may
    differ in size) according to the given mask. The legality rules are:
    * the two operands must have the same element type as the result
    * the two operands and the result must have the same rank and trailing
      dimension sizes, viz. given two k-D operands
              v1 : <s_1 x s_2 x .. x s_k x type> and
              v2 : <t_1 x t_2 x .. x t_k x type>
      we have s_i = t_i for all 1 < i <= k
    * the mask length equals the leading dimension size of the result
    * numbering the input vector indices left to right across the operands, all
      mask values must be within range, viz. given two k-D operands v1 and v2
      above, all mask values are in the range [0,s_1+t_1)

    Examples:
    ```
    %0 = vector.shuffle %a, %b[0:i32, 3:i32]
               : vector<2xf32>, vector<2xf32>       ; yields vector<2xf32>
    %1 = vector.shuffle %c, %b[0:i32, 1:i32, 2:i32]
               : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
    %2 = vector.shuffle %a, %b[3:i32, 2:i32, 1:i32, 0:i32]
               : vector<2xf32>, vector<2xf32>       ; yields vector<4xf32>
    ```
  }];
  // Custom builder taking the two operands and the mask as plain integers.
  let builders = [OpBuilder<"Builder *builder, OperationState &result, Value *v1, Value *v2, ArrayRef<int32_t>">];
  let extraClassDeclaration = [{
    // Name under which the shuffle mask is stored in the attribute dictionary.
    static StringRef getMaskAttrName() { return "mask"; }
    // Type of the first operand, as a VectorType.
    VectorType getV1VectorType() {
      return v1()->getType().cast<VectorType>();
    }
    // Type of the second operand, as a VectorType.
    VectorType getV2VectorType() {
      return v2()->getType().cast<VectorType>();
    }
    // Type of the result, as a VectorType.
    VectorType getVectorType() {
      return vector()->getType().cast<VectorType>();
    }
  }];
}
269+
217270
def Vector_ExtractOp :
218271
Vector_Op<"extract", [NoSideEffect,
219272
PredOpTrait<"operand and result have same element type",

mlir/lib/Dialect/VectorOps/VectorOps.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,92 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser,
458458
parser.addTypeToList(vectorType, result.types));
459459
}
460460

461+
//===----------------------------------------------------------------------===//
462+
// ShuffleOp
463+
//===----------------------------------------------------------------------===//
464+
465+
/// Builds a ShuffleOp that shuffles `v1` and `v2` according to `mask`.
///
/// The result type is derived from `v1` and `mask`: it has v1's element type
/// and trailing dimension sizes, and a leading dimension equal to the mask
/// length. This is the same shape the custom parser constructs and the shape
/// the verifier requires ("mask length mismatch" otherwise), so a built op
/// verifies for any in-range mask, not only for masks whose length happens to
/// equal v1's leading dimension.
///
/// `v1` must have a vector type (it is cast to VectorType). `mask` is expected
/// to be non-empty, mirroring the parser, which rejects an empty mask.
void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1,
                      Value *v2, ArrayRef<int32_t> mask) {
  result.addOperands({v1, v2});

  // Result shape: mask length in the leading dimension, followed by the
  // trailing dimension sizes of v1.
  auto v1Type = v1->getType().cast<VectorType>();
  int64_t v1Rank = v1Type.getRank();
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(static_cast<int64_t>(mask.size()));
  for (int64_t r = 1; r < v1Rank; ++r)
    shape.push_back(v1Type.getDimSize(r));
  result.addTypes(VectorType::get(shape, v1Type.getElementType()));

  result.addAttribute(getMaskAttrName(), builder->getI32ArrayAttr(mask));
}
472+
473+
/// Prints a ShuffleOp in its custom assembly form:
///   <op-name> <v1>, <v2> <mask> <optional-attr-dict> : <v1-type>, <v2-type>
/// The mask is printed inline after the operands, so it is elided from the
/// trailing attribute dictionary.
static void print(OpAsmPrinter &p, ShuffleOp op) {
  // Operation name and the two operands.
  p << op.getOperationName() << " ";
  p << *op.v1() << ", " << *op.v2();
  // Inline mask attribute.
  p << " " << op.mask();
  // Any remaining attributes, without repeating the mask.
  p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
  // Operand types; the result type is not printed.
  p << " : " << op.v1()->getType();
  p << ", " << op.v2()->getType();
}
479+
480+
/// Verifies the structural constraints of a ShuffleOp, in this order:
///   1. both operands and the result have the same rank;
///   2. every non-leading dimension size agrees across operands and result;
///   3. the mask length equals the leading dimension size of the result;
///   4. every mask entry is an integer attribute whose value lies in
///      [0, leading(v1) + leading(v2)).
/// Emits an op error for the first violated constraint and returns failure;
/// returns success when all constraints hold.
static LogicalResult verify(ShuffleOp op) {
  VectorType lhsType = op.getV1VectorType();
  VectorType rhsType = op.getV2VectorType();
  VectorType outType = op.getVectorType();

  // Ranks must agree pairwise.
  const int64_t rank = lhsType.getRank();
  if (outType.getRank() != rank || rank != rhsType.getRank())
    return op.emitOpError("rank mismatch");

  // Every dimension except the leading one must have the same size.
  for (int64_t dim = 1; dim < rank; ++dim) {
    const int64_t expected = lhsType.getDimSize(dim);
    if (outType.getDimSize(dim) != expected ||
        expected != rhsType.getDimSize(dim))
      return op.emitOpError("dimension mismatch");
  }

  // The mask must provide exactly one entry per leading-dimension position of
  // the result.
  ArrayRef<Attribute> maskElements = op.mask().getValue();
  const int64_t numMaskElements = maskElements.size();
  if (numMaskElements != outType.getDimSize(0))
    return op.emitOpError("mask length mismatch");

  // Mask entries index into the leading dimensions of v1 and v2 taken
  // together. Positions are reported 1-based in the diagnostic.
  const int64_t numSources = lhsType.getDimSize(0) + rhsType.getDimSize(0);
  size_t position = 0;
  for (Attribute element : maskElements) {
    ++position;
    auto index = element.dyn_cast<IntegerAttr>();
    const bool inRange =
        index && index.getInt() >= 0 && index.getInt() < numSources;
    if (!inRange)
      return op.emitOpError("mask index #") << position << " out of range";
  }
  return success();
}
513+
514+
/// Parses a ShuffleOp from its custom assembly form:
///   <v1>, <v2> <mask-attr> <optional-attr-dict> : <v1-type>, <v2-type>
/// The result type is not spelled in the assembly. It is constructed here
/// from the mask length (leading dimension) and the trailing dimension sizes
/// and element type of the first operand's type.
/// Fails if any component cannot be parsed, if the mask is not an array
/// attribute ("missing mask attribute"), or if the mask is empty
/// ("invalid mask length").
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType lhs, rhs;
  Attribute maskValue;
  VectorType lhsType, rhsType;

  // Operands: `v1 , v2`.
  if (parser.parseOperand(lhs) || parser.parseComma() ||
      parser.parseOperand(rhs))
    return failure();

  // Mask attribute (stored under the mask attribute name), then any
  // additional attributes.
  if (parser.parseAttribute(maskValue, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Operand types: `: v1-type , v2-type`.
  if (parser.parseColonType(lhsType) || parser.parseComma() ||
      parser.parseType(rhsType))
    return failure();

  // Bind the parsed operands to their declared types.
  if (parser.resolveOperand(lhs, lhsType, result.operands) ||
      parser.resolveOperand(rhs, rhsType, result.operands))
    return failure();

  // The mask must be a non-empty array attribute.
  auto mask = maskValue.dyn_cast<ArrayAttr>();
  if (!mask)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  const int64_t numMaskElements = mask.size();
  if (numMaskElements <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");

  // Result shape: mask length first, then the trailing dimensions of v1.
  const int64_t rank = lhsType.getRank();
  SmallVector<int64_t, 4> resultShape;
  resultShape.reserve(rank);
  resultShape.push_back(numMaskElements);
  for (int64_t dim = 1; dim < rank; ++dim)
    resultShape.push_back(lhsType.getDimSize(dim));

  VectorType resultType =
      VectorType::get(resultShape, lhsType.getElementType());
  parser.addTypeToList(resultType, result.types);
  return success();
}
546+
461547
//===----------------------------------------------------------------------===//
462548
// InsertOp
463549
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,18 +235,18 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32>
235235
return %0 : vector<3x16xf32>
236236
}
237237
// CHECK-LABEL: extract_vec_2d_from_vec_3d
238-
// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
239-
// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]">
238+
// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
239+
// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]">
240240

241241
func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
242242
%0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
243243
return %0 : f32
244244
}
245245
// CHECK-LABEL: extract_element_from_vec_3d
246-
// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
246+
// CHECK: llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
247247
// CHECK: llvm.mlir.constant(0 : i32) : !llvm.i32
248-
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
249-
// CHECK: llvm.return %{{.*}} : !llvm.float
248+
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
249+
// CHECK: llvm.return {{.*}} : !llvm.float
250250

251251
func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
252252
%0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>

mlir/test/Dialect/VectorOps/invalid.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,41 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
3131

3232
// -----
3333

34+
// Negative test: the second operand has element type i32 while the first
// operand (and hence the parsed result) has element type f32, so the
// element-type predicate on v2 must reject the op.
func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
35+
// expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
36+
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xi32>
37+
}
38+
39+
// -----
40+
41+
// Negative test: the operands have different ranks (1-D vs. 2-D), which the
// op verifier must reject.
func @shuffle_rank_mismatch(%arg0: vector<2xf32>, %arg1: vector<4x2xf32>) {
42+
// expected-error@+1 {{'vector.shuffle' op rank mismatch}}
43+
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<4x2xf32>
44+
}
45+
46+
// -----
47+
48+
// Negative test: the operands have the same rank but different trailing
// dimension sizes (2 vs. 4), which the op verifier must reject.
func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) {
49+
// expected-error@+1 {{'vector.shuffle' op dimension mismatch}}
50+
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2x2xf32>, vector<2x4xf32>
51+
}
52+
53+
// -----
54+
55+
// Negative test: with two 2-element operands the valid mask values are
// 0 through 3, so the second mask entry (4) is out of range. The diagnostic
// numbers mask entries starting at 1.
func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
56+
// expected-error@+1 {{'vector.shuffle' op mask index #2 out of range}}
57+
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 4 : i32] : vector<2xf32>, vector<2xf32>
58+
}
59+
60+
// -----
61+
62+
// Negative test: an empty mask is rejected by the custom parser (note the
// "custom op" prefix in the diagnostic) rather than by the op verifier.
func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
63+
// expected-error@+1 {{custom op 'vector.shuffle' invalid mask length}}
64+
%1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32>
65+
}
66+
67+
// -----
68+
3469
func @extract_vector_type(%arg0: index) {
3570
// expected-error@+1 {{expected vector type}}
3671
%1 = vector.extract %arg0[] : index

mlir/test/Dialect/VectorOps/ops.mlir

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,38 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
2424

2525
// CHECK-LABEL: @vector_broadcast
2626
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
27-
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
27+
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
2828
%0 = vector.broadcast %a : f32 to vector<16xf32>
29-
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
29+
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
3030
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
31-
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
31+
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
3232
%2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
33-
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
33+
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
3434
%3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
3535
return %3 : vector<8x16xf32>
3636
}
3737

38+
// CHECK-LABEL: @shuffle1D
39+
// Round-trip test for 1-D shuffles. The result types are not spelled in the
// assembly; they follow from the mask lengths: the first shuffle yields a
// 4-element vector, the second a 3-element vector, and the third a 2-element
// vector, which is returned. The last mask uses index 6, which is valid for
// operands with 3 and 4 elements (valid indices are 0 through 6).
func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> {
40+
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32, 3 : i32] : vector<2xf32>, vector<2xf32>
41+
%1 = vector.shuffle %a, %a[0 : i32, 1 : i32, 2: i32, 3 : i32] : vector<2xf32>, vector<2xf32>
42+
// CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32>
43+
%2 = vector.shuffle %1, %b[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32>
44+
// CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32>
45+
%3 = vector.shuffle %2, %b[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32>
46+
return %3 : vector<2xf32>
47+
}
48+
49+
// CHECK-LABEL: @shuffle2D
50+
// Round-trip test for a 2-D shuffle whose operands differ only in the leading
// dimension (1 vs. 2). The three-entry mask yields a result with leading
// dimension 3 and the shared trailing dimension 4.
func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
51+
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<1x4xf32>, vector<2x4xf32>
52+
%1 = vector.shuffle %a, %b[0 : i32, 1 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
53+
return %1 : vector<3x4xf32>
54+
}
55+
3856
// CHECK-LABEL: @extract
3957
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
40-
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
58+
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
4159
%1 = vector.extract %arg0[3 : i32] : vector<4x8x16xf32>
4260
// CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32>
4361
%2 = vector.extract %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32>
@@ -47,35 +65,35 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f
4765
}
4866

4967
// CHECK-LABEL: @insert
50-
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
51-
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
68+
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
69+
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
5270
%1 = vector.insert %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
53-
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
71+
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
5472
%2 = vector.insert %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
55-
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
73+
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
5674
%3 = vector.insert %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
57-
return
75+
return %3 : vector<4x8x16xf32>
5876
}
5977

6078
// CHECK-LABEL: @outerproduct
6179
func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
62-
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
80+
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
6381
%0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
64-
// CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
82+
// CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
6583
%1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
6684
return %1 : vector<4x8xf32>
6785
}
6886

6987
// CHECK-LABEL: @insert_strided_slice
7088
func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
71-
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
89+
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
7290
%1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
7391
return
7492
}
7593

7694
// CHECK-LABEL: @strided_slice
7795
func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
78-
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
96+
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
7997
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
8098
return %1: vector<2x2x16xf32>
8199
}

0 commit comments

Comments
 (0)