Skip to content

Commit 714aee3

Browse files
authored
[mlir][vector] Add result type to interleave assembly format (#93392)
This is to make it more obvious for what the result type is, especially with some less trivial cases like 0-d inputs resulting in 1-d inputs or interaction with scalable vector types. Note that `vector.deinterleave` uses the same format with explicit result type. Also improve examples and clean up surrounding code.
1 parent 8cdecd4 commit 714aee3

File tree

11 files changed

+62
-70
lines changed

11 files changed

+62
-70
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -480,24 +480,25 @@ def Vector_ShuffleOp :
480480
let hasCanonicalizer = 1;
481481
}
482482

483-
def Vector_InterleaveOp :
484-
Vector_Op<"interleave", [Pure,
485-
AllTypesMatch<["lhs", "rhs"]>,
486-
TypesMatchWith<
483+
def ResultIsDoubleSourceVectorType : TypesMatchWith<
487484
"type of 'result' is double the width of the inputs",
488485
"lhs", "result",
489486
[{
490487
[&]() -> ::mlir::VectorType {
491-
auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
488+
auto vectorType = ::llvm::cast<::mlir::VectorType>($_self);
492489
::mlir::VectorType::Builder builder(vectorType);
493490
if (vectorType.getRank() == 0) {
494-
static constexpr int64_t v2xty_shape[] = { 2 };
495-
return builder.setShape(v2xty_shape);
491+
static constexpr int64_t v2xTyShape[] = {2};
492+
return builder.setShape(v2xTyShape);
496493
}
497494
auto lastDim = vectorType.getRank() - 1;
498495
return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
499496
}()
500-
}]>]> {
497+
}]>;
498+
499+
def Vector_InterleaveOp :
500+
Vector_Op<"interleave", [Pure, AllTypesMatch<["lhs", "rhs"]>,
501+
ResultIsDoubleSourceVectorType]> {
501502
let summary = "constructs a vector by interleaving two input vectors";
502503
let description = [{
503504
The interleave operation constructs a new vector by interleaving the
@@ -513,24 +514,24 @@ def Vector_InterleaveOp :
513514

514515
Example:
515516
```mlir
516-
%0 = vector.interleave %a, %b
517-
: vector<[4]xi32> ; yields vector<[8]xi32>
518-
%1 = vector.interleave %c, %d
519-
: vector<8xi8> ; yields vector<16xi8>
520-
%2 = vector.interleave %e, %f
521-
: vector<f16> ; yields vector<2xf16>
522-
%3 = vector.interleave %g, %h
523-
: vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
524-
%4 = vector.interleave %i, %j
525-
: vector<6x3xf32> ; yields vector<6x6xf32>
517+
%a = arith.constant dense<[0, 1]> : vector<2xi32>
518+
%b = arith.constant dense<[2, 3]> : vector<2xi32>
519+
// The value of `%0` is `[0, 2, 1, 3]`.
520+
%0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
521+
522+
// Examples showing allowed input and result types.
523+
%1 = vector.interleave %c, %d : vector<f16> -> vector<2xf16>
524+
%2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32>
525+
%3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32>
526+
%4 = vector.interleave %i, %j : vector<2x4x[2]xf64> -> vector<2x4x[4]xf64>
526527
```
527528
}];
528529

529530
let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
530531
let results = (outs AnyVector:$result);
531532

532533
let assemblyFormat = [{
533-
$lhs `,` $rhs attr-dict `:` type($lhs)
534+
$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
534535
}];
535536

536537
let extraClassDeclaration = [{

mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace {
3030
/// Example:
3131
///
3232
/// ```mlir
33-
/// vector.interleave %a, %b : vector<1x2x3x4xi64>
33+
/// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
3434
/// ```
3535
/// Would be unrolled to:
3636
/// ```mlir
@@ -39,14 +39,15 @@ namespace {
3939
/// : vector<4xi64> from vector<1x2x3x4xi64> |
4040
/// %1 = vector.extract %b[0, 0, 0] |
4141
/// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
42-
/// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions
42+
/// %2 = vector.interleave %0, %1 : | all leading positions
43+
/// : vector<4xi64> -> vector<8xi64> |
4344
/// %3 = vector.insert %2, %result [0, 0, 0] |
4445
/// : vector<8xi64> into vector<1x2x3x8xi64> ┘
4546
/// ```
4647
///
4748
/// Note: If any leading dimension before the `targetRank` is scalable the
4849
/// unrolling will stop before the scalable dimension.
49-
class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
50+
class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
5051
public:
5152
UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
5253
PatternBenefit benefit = 1)
@@ -84,7 +85,7 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
8485
/// Example:
8586
///
8687
/// ```mlir
87-
/// vector.interleave %a, %b : vector<7xi16>
88+
/// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
8889
/// ```
8990
///
9091
/// Is rewritten into:
@@ -93,10 +94,8 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
9394
/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
9495
/// : vector<7xi16>, vector<7xi16>
9596
/// ```
96-
class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
97-
public:
98-
InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
99-
: OpRewritePattern(context, benefit) {};
97+
struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
98+
using OpRewritePattern::OpRewritePattern;
10099

101100
LogicalResult matchAndRewrite(vector::InterleaveOp op,
102101
PatternRewriter &rewriter) const override {

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10901090
/// %1 = arith.shli %0, 4 : vector<4xi8>
10911091
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
10921092
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
1093-
/// %4 = vector.interleave %2, %3 : vector<4xi8>
1093+
/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
10941094
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
10951095
///
10961096
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1099,7 +1099,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10991099
/// %1 = arith.shli %0, 4 : vector<4xi8>
11001100
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
11011101
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
1102-
/// %4 = vector.interleave %2, %3 : vector<4xi8>
1102+
/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
11031103
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
11041104
///
11051105
/// Example (unsigned):
@@ -1108,7 +1108,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
11081108
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
11091109
/// %1 = arith.andi %0, 15 : vector<4xi8>
11101110
/// %2 = arith.shrui %0, 4 : vector<4xi8>
1111-
/// %3 = vector.interleave %1, %2 : vector<4xi8>
1111+
/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
11121112
/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
11131113
///
11141114
template <typename ConversionOpType, bool isSigned>

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,54 +2495,50 @@ func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8>
24952495
// CHECK: %[[RHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<i8> to vector<1xi8>
24962496
// CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS_RANK1]], %[[RHS_RANK1]] [0, 1] : vector<1xi8>
24972497
// CHECK: return %[[ZIP]]
2498-
%0 = vector.interleave %a, %b : vector<i8>
2498+
%0 = vector.interleave %a, %b : vector<i8> -> vector<2xi8>
24992499
return %0 : vector<2xi8>
25002500
}
25012501

25022502
// -----
25032503

25042504
// CHECK-LABEL: @vector_interleave_1d
25052505
// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>, %[[RHS:.*]]: vector<8xf32>)
2506-
func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32>
2507-
{
2506+
func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32> {
25082507
// CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS]], %[[RHS]] [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>
25092508
// CHECK: return %[[ZIP]]
2510-
%0 = vector.interleave %a, %b : vector<8xf32>
2509+
%0 = vector.interleave %a, %b : vector<8xf32> -> vector<16xf32>
25112510
return %0 : vector<16xf32>
25122511
}
25132512

25142513
// -----
25152514

25162515
// CHECK-LABEL: @vector_interleave_1d_scalable
25172516
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, %[[RHS:.*]]: vector<[4]xi32>)
2518-
func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32>
2519-
{
2517+
func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32> {
25202518
// CHECK: %[[ZIP:.*]] = "llvm.intr.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
25212519
// CHECK: return %[[ZIP]]
2522-
%0 = vector.interleave %a, %b : vector<[4]xi32>
2520+
%0 = vector.interleave %a, %b : vector<[4]xi32> -> vector<[8]xi32>
25232521
return %0 : vector<[8]xi32>
25242522
}
25252523

25262524
// -----
25272525

25282526
// CHECK-LABEL: @vector_interleave_2d
25292527
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
2530-
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
2531-
{
2528+
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
25322529
// CHECK: llvm.shufflevector
25332530
// CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
2534-
%0 = vector.interleave %a, %b : vector<2x3xi8>
2531+
%0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
25352532
return %0 : vector<2x6xi8>
25362533
}
25372534

25382535
// -----
25392536

25402537
// CHECK-LABEL: @vector_interleave_2d_scalable
25412538
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
2542-
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
2543-
{
2539+
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
25442540
// CHECK: llvm.intr.vector.interleave2
25452541
// CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
2546-
%0 = vector.interleave %a, %b : vector<2x[8]xi16>
2542+
%0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
25472543
return %0 : vector<2x[16]xi16>
25482544
}

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
488488
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
489489
// CHECK: return %[[SHUFFLE]]
490490
func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
491-
%0 = vector.interleave %a, %b : vector<2xf32>
491+
%0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32>
492492
return %0 : vector<4xf32>
493493
}
494494

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,9 +2576,8 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
25762576

25772577
// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
25782578
// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
2579-
func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
2580-
{
2581-
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
2579+
func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64> {
2580+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64> -> vector<2xf64>
25822581
// CHECK: return %[[ZIP]]
25832582
%0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
25842583
return %0 : vector<2xf64>
@@ -2589,7 +2588,7 @@ func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>)
25892588
// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
25902589
// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
25912590
func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
2592-
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
2591+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32> -> vector<12xi32>
25932592
// CHECK: return %[[ZIP]]
25942593
%0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
25952594
return %0 : vector<12xi32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,36 +1084,36 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
10841084

10851085
// CHECK-LABEL: @interleave_0d
10861086
func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
1087-
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
1088-
%0 = vector.interleave %a, %b : vector<f32>
1087+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32> -> vector<2xf32>
1088+
%0 = vector.interleave %a, %b : vector<f32> -> vector<2xf32>
10891089
return %0 : vector<2xf32>
10901090
}
10911091

10921092
// CHECK-LABEL: @interleave_1d
10931093
func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
10941094
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
1095-
%0 = vector.interleave %a, %b : vector<4xf32>
1095+
%0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>
10961096
return %0 : vector<8xf32>
10971097
}
10981098

10991099
// CHECK-LABEL: @interleave_1d_scalable
11001100
func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
11011101
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
1102-
%0 = vector.interleave %a, %b : vector<[8]xi16>
1102+
%0 = vector.interleave %a, %b : vector<[8]xi16> -> vector<[16]xi16>
11031103
return %0 : vector<[16]xi16>
11041104
}
11051105

11061106
// CHECK-LABEL: @interleave_2d
11071107
func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
11081108
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
1109-
%0 = vector.interleave %a, %b : vector<2x8xf32>
1109+
%0 = vector.interleave %a, %b : vector<2x8xf32> -> vector<2x16xf32>
11101110
return %0 : vector<2x16xf32>
11111111
}
11121112

11131113
// CHECK-LABEL: @interleave_2d_scalable
11141114
func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
11151115
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
1116-
%0 = vector.interleave %a, %b : vector<2x[2]xf64>
1116+
%0 = vector.interleave %a, %b : vector<2x[2]xf64> -> vector<2x[4]xf64>
11171117
return %0 : vector<2x[4]xf64>
11181118
}
11191119

mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
// CHECK-LABEL: @vector_interleave_2d
44
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
5-
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
6-
{
5+
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
76
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
87
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
98
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -14,14 +13,13 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto
1413
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
1514
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
1615
// CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
17-
%0 = vector.interleave %a, %b : vector<2x3xi8>
16+
%0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
1817
return %0 : vector<2x6xi8>
1918
}
2019

2120
// CHECK-LABEL: @vector_interleave_2d_scalable
2221
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
23-
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
24-
{
22+
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
2523
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
2624
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
2725
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -32,7 +30,7 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
3230
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
3331
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
3432
// CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
35-
%0 = vector.interleave %a, %b : vector<2x[8]xi16>
33+
%0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
3634
return %0 : vector<2x[16]xi16>
3735
}
3836

@@ -44,17 +42,17 @@ func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>
4442
// CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
4543
// CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
4644
// CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
47-
// CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64>
48-
%0 = vector.interleave %a, %b : vector<1x2x3x4xi64>
45+
// CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64> -> vector<8xi64>
46+
%0 = vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
4947
return %0 : vector<1x2x3x8xi64>
5048
}
5149

5250
// CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
53-
func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16>
54-
{
51+
func.func @vector_interleave_nd_with_scalable_dim(
52+
%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16> {
5553
// The scalable dim blocks unrolling so only the first two dims are unrolled.
5654
// CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
57-
%0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16>
55+
%0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16> -> vector<1x3x[2]x2x3x8xf16>
5856
return %0 : vector<1x3x[2]x2x3x8xf16>
5957
}
6058

mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
22

33
// CHECK-LABEL: @vector_interleave_to_shuffle
4-
func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
5-
{
6-
%0 = vector.interleave %a, %b : vector<7xi16>
4+
func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16> {
5+
%0 = vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
76
return %0 : vector<14xi16>
87
}
98
// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func.func @entry() {
1717
// CHECK: ( 1, 1, 1, 1
1818
// CHECK: ( 2, 2, 2, 2
1919

20-
%v3 = vector.interleave %v1, %v2 : vector<[4]xf32>
20+
%v3 = vector.interleave %v1, %v2 : vector<[4]xf32> -> vector<[8]xf32>
2121
vector.print %v3 : vector<[8]xf32>
2222
// CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2
2323

mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func.func @entry() {
1616
// CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) )
1717
// CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) )
1818

19-
%v3 = vector.interleave %v1, %v2 : vector<2x4xf32>
19+
%v3 = vector.interleave %v1, %v2 : vector<2x4xf32> -> vector<2x8xf32>
2020
vector.print %v3 : vector<2x8xf32>
2121
// CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) )
2222

0 commit comments

Comments
 (0)