-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Add result type to interleave
assembly format
#93392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This is to make it more obvious for what the result type is, especially with some less trivial cases like 0-d inputs resulting in 1-d inputs or interaction with scalable vector types. Note that `vector.deinterleave` uses the same format with explicit result type. Also improve examples and clean up surrounding code.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-spirv Author: Jakub Kuderski (kuhar) ChangesThis is to make it more obvious for what the result type is, especially with some less trivial cases like 0-d inputs resulting in 1-d inputs or interaction with scalable vector types. Note that Also improve examples and clean up surrounding code. Patch is 20.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93392.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2bb7540ef0b0f..e043320b56411 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -480,24 +480,25 @@ def Vector_ShuffleOp :
let hasCanonicalizer = 1;
}
-def Vector_InterleaveOp :
- Vector_Op<"interleave", [Pure,
- AllTypesMatch<["lhs", "rhs"]>,
- TypesMatchWith<
+def ResultIsDoubleSourceVectorType : TypesMatchWith<
"type of 'result' is double the width of the inputs",
"lhs", "result",
[{
[&]() -> ::mlir::VectorType {
- auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
+ auto vectorType = ::llvm::cast<::mlir::VectorType>($_self);
::mlir::VectorType::Builder builder(vectorType);
if (vectorType.getRank() == 0) {
- static constexpr int64_t v2xty_shape[] = { 2 };
- return builder.setShape(v2xty_shape);
+ static constexpr int64_t v2xTyShape[] = {2};
+ return builder.setShape(v2xTyShape);
}
auto lastDim = vectorType.getRank() - 1;
return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
}()
- }]>]> {
+ }]>;
+
+def Vector_InterleaveOp :
+ Vector_Op<"interleave", [Pure, AllTypesMatch<["lhs", "rhs"]>,
+ ResultIsDoubleSourceVectorType]> {
let summary = "constructs a vector by interleaving two input vectors";
let description = [{
The interleave operation constructs a new vector by interleaving the
@@ -513,16 +514,15 @@ def Vector_InterleaveOp :
Example:
```mlir
- %0 = vector.interleave %a, %b
- : vector<[4]xi32> ; yields vector<[8]xi32>
- %1 = vector.interleave %c, %d
- : vector<8xi8> ; yields vector<16xi8>
- %2 = vector.interleave %e, %f
- : vector<f16> ; yields vector<2xf16>
- %3 = vector.interleave %g, %h
- : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
- %4 = vector.interleave %i, %j
- : vector<6x3xf32> ; yields vector<6x6xf32>
+ %a = arith.constant dense<[0, 1]> : vector<2xi32>
+ %b = arith.constant dense<[2, 3]> : vector<2xi32>
+ %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
+ // The value of `%0` is `[0, 2, 1, 3]`.
+
+ %1 = vector.interleave %c, %d : vector<f16> -> vector<2xf16>
+ %2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32>
+ %3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32>
+ %4 = vector.interleave %i, %j : vector<2x4x[2]xf64> -> vector<2x4x[4]xf64>
```
}];
@@ -530,7 +530,7 @@ def Vector_InterleaveOp :
let results = (outs AnyVector:$result);
let assemblyFormat = [{
- $lhs `,` $rhs attr-dict `:` type($lhs)
+ $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 5326760c9b4eb..77c97b2f1497c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -30,7 +30,7 @@ namespace {
/// Example:
///
/// ```mlir
-/// vector.interleave %a, %b : vector<1x2x3x4xi64>
+/// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
/// ```
/// Would be unrolled to:
/// ```mlir
@@ -39,14 +39,15 @@ namespace {
/// : vector<4xi64> from vector<1x2x3x4xi64> |
/// %1 = vector.extract %b[0, 0, 0] |
/// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
-/// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions
+/// %2 = vector.interleave %0, %1 : | all leading positions
+/// : vector<4xi64> -> vector<8xi64> |
/// %3 = vector.insert %2, %result [0, 0, 0] |
/// : vector<8xi64> into vector<1x2x3x8xi64> ┘
/// ```
///
/// Note: If any leading dimension before the `targetRank` is scalable the
/// unrolling will stop before the scalable dimension.
-class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
+class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
public:
UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
PatternBenefit benefit = 1)
@@ -84,7 +85,7 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
/// Example:
///
/// ```mlir
-/// vector.interleave %a, %b : vector<7xi16>
+/// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
/// ```
///
/// Is rewritten into:
@@ -93,10 +94,8 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
/// : vector<7xi16>, vector<7xi16>
/// ```
-class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
-public:
- InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit) {};
+struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::InterleaveOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6025c4ad7c145..59b6cb3ae667a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1090,7 +1090,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1099,7 +1099,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
/// Example (unsigned):
@@ -1108,7 +1108,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
/// %1 = arith.andi %0, 15 : vector<4xi8>
/// %2 = arith.shrui %0, 4 : vector<4xi8>
-/// %3 = vector.interleave %1, %2 : vector<4xi8>
+/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
///
template <typename ConversionOpType, bool isSigned>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 439f1e920e392..a7a0ca3d43b01 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2495,7 +2495,7 @@ func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8>
// CHECK: %[[RHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<i8> to vector<1xi8>
// CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS_RANK1]], %[[RHS_RANK1]] [0, 1] : vector<1xi8>
// CHECK: return %[[ZIP]]
- %0 = vector.interleave %a, %b : vector<i8>
+ %0 = vector.interleave %a, %b : vector<i8> -> vector<2xi8>
return %0 : vector<2xi8>
}
@@ -2503,11 +2503,10 @@ func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8>
// CHECK-LABEL: @vector_interleave_1d
// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>, %[[RHS:.*]]: vector<8xf32>)
-func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32>
-{
+func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32> {
// CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS]], %[[RHS]] [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>
// CHECK: return %[[ZIP]]
- %0 = vector.interleave %a, %b : vector<8xf32>
+ %0 = vector.interleave %a, %b : vector<8xf32> -> vector<16xf32>
return %0 : vector<16xf32>
}
@@ -2515,11 +2514,10 @@ func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<
// CHECK-LABEL: @vector_interleave_1d_scalable
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, %[[RHS:.*]]: vector<[4]xi32>)
-func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32>
-{
+func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32> {
// CHECK: %[[ZIP:.*]] = "llvm.intr.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
// CHECK: return %[[ZIP]]
- %0 = vector.interleave %a, %b : vector<[4]xi32>
+ %0 = vector.interleave %a, %b : vector<[4]xi32> -> vector<[8]xi32>
return %0 : vector<[8]xi32>
}
@@ -2527,11 +2525,10 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
// CHECK-LABEL: @vector_interleave_2d
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
-func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
-{
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
// CHECK: llvm.shufflevector
// CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
- %0 = vector.interleave %a, %b : vector<2x3xi8>
+ %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
return %0 : vector<2x6xi8>
}
@@ -2539,10 +2536,9 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto
// CHECK-LABEL: @vector_interleave_2d_scalable
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
-func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
-{
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
// CHECK: llvm.intr.vector.interleave2
// CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
- %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+ %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
return %0 : vector<2x[16]xi16>
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a7542086aa766..b24088d951259 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -488,7 +488,7 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
// CHECK: return %[[SHUFFLE]]
func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
- %0 = vector.interleave %a, %b : vector<2xf32>
+ %0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32>
return %0 : vector<4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61a5f2a96e1c1..22af91e0eb327 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2576,9 +2576,8 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
-func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
-{
- // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64> {
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64> -> vector<2xf64>
// CHECK: return %[[ZIP]]
%0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
return %0 : vector<2xf64>
@@ -2589,7 +2588,7 @@ func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>)
// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
- // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32> -> vector<12xi32>
// CHECK: return %[[ZIP]]
%0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
return %0 : vector<12xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 9d8101d3eee97..c868c881d079a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1084,36 +1084,36 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
// CHECK-LABEL: @interleave_0d
func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
- // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
- %0 = vector.interleave %a, %b : vector<f32>
+ // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32> -> vector<2xf32>
+ %0 = vector.interleave %a, %b : vector<f32> -> vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: @interleave_1d
func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
- %0 = vector.interleave %a, %b : vector<4xf32>
+ %0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: @interleave_1d_scalable
func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
- %0 = vector.interleave %a, %b : vector<[8]xi16>
+ %0 = vector.interleave %a, %b : vector<[8]xi16> -> vector<[16]xi16>
return %0 : vector<[16]xi16>
}
// CHECK-LABEL: @interleave_2d
func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
- %0 = vector.interleave %a, %b : vector<2x8xf32>
+ %0 = vector.interleave %a, %b : vector<2x8xf32> -> vector<2x16xf32>
return %0 : vector<2x16xf32>
}
// CHECK-LABEL: @interleave_2d_scalable
func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
- %0 = vector.interleave %a, %b : vector<2x[2]xf64>
+ %0 = vector.interleave %a, %b : vector<2x[2]xf64> -> vector<2x[4]xf64>
return %0 : vector<2x[4]xf64>
}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
index 3dd4857860eb1..598f7d70b4f1b 100644
--- a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
@@ -2,8 +2,7 @@
// CHECK-LABEL: @vector_interleave_2d
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
-func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
-{
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -14,14 +13,13 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
// CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
- %0 = vector.interleave %a, %b : vector<2x3xi8>
+ %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
return %0 : vector<2x6xi8>
}
// CHECK-LABEL: @vector_interleave_2d_scalable
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
-func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
-{
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -32,7 +30,7 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
// CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
- %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+ %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
return %0 : vector<2x[16]xi16>
}
@@ -44,17 +42,17 @@ func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>
// CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
// CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
// CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
- // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64>
- %0 = vector.interleave %a, %b : vector<1x2x3x4xi64>
+ // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64> -> vector<8xi64>
+ %0 = vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
return %0 : vector<1x2x3x8xi64>
}
// CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
-func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16>
-{
+func.func @vector_interleave_nd_with_scalable_dim(
+ %a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16> {
// The scalable dim blocks unrolling so only the first two dims are unrolled.
// CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
- %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16>
+ %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16> -> vector<1x3x[2]x2x3x8xf16>
return %0 : vector<1x3x[2]x2x3x8xf16>
}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
index ed3b3396bf3ea..d59cd4e6765ba 100644
--- a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -1,9 +1,8 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
// CHECK-LABEL: @vector_interleave_to_shuffle
-func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
-{
- %0 = vector.interleave %a, %b : vector<7xi16>
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16> {
+ %0 = vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
return %0 : vector<14xi16>
}
// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
index 07989bd71f501..e9f1bbeafacdd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
// CHECK: ( 1, 1, 1, ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM!
I left a couple of nits - feel free to ignore.
This is to make it more obvious for what the result type is, especially with some less trivial cases like 0-d inputs resulting in 1-d inputs or interaction with scalable vector types. Note that
vector.deinterleave
uses the same format with explicit result type.Also improve examples and clean up surrounding code.