Skip to content

[mlir][spirv] Implement SPIR-V lowering for vector.deinterleave #95313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 13, 2024

Conversation

angelz913
Copy link
Contributor

  1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass.
  2. Added LIT tests for the new conversion.

@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

Changes
  1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass.
  2. Added LIT tests for the new conversion.

Full diff: https://github.com/llvm/llvm-project/pull/95313.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+71-3)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+50)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 92168cfa36147..b9a086cfc91a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
   }
 };
 
+struct VectorDeinterleaveOpConvert final
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Check the result vector type.
+    VectorType oldResultType = deinterleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "unsupported result vector type");
+
+    // Get location.
+    Location loc = deinterleaveOp->getLoc();
+
+    // Deinterleave the indices.
+    VectorType sourceType = deinterleaveOp.getSourceVectorType();
+    int n = sourceType.getNumElements();
+
+    // Output vectors of size 1 are converted to scalars by the type converter.
+    // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+    // use `spirv::CompositeExtractOp`.
+    if (n == 2) {
+      spirv::CompositeExtractOp compositeExtractZero =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({0}));
+
+      spirv::CompositeExtractOp compositeExtractOne =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({1}));
+
+      rewriter.replaceOp(deinterleaveOp,
+                         {compositeExtractZero, compositeExtractOne});
+      return success();
+    }
+
+    // Indices for `res1`.
+    auto seqEven = llvm::seq<int64_t>(n / 2);
+    auto indicesEven =
+        llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
+
+    // Indices for `res2`.
+    auto seqOdd = llvm::seq<int64_t>(n / 2);
+    auto indicesOdd =
+        llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
+
+    // Create two SPIR-V shuffles.
+    spirv::VectorShuffleOp shuffleEven =
+        rewriter.create<spirv::VectorShuffleOp>(
+            loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+            rewriter.getI32ArrayAttr(indicesEven));
+
+    spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+        loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+        rewriter.getI32ArrayAttr(indicesOdd));
+
+    // Replace deinterleaveOp with SPIR-V shuffles.
+    rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
+
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -862,9 +930,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
-      VectorStoreOpConverter>(typeConverter, patterns.getContext(),
-                              PatternBenefit(1));
+      VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 2592d0fc04111..87823ab9afc0f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,6 +507,56 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
 
 // -----
 
+// CHECK-LABEL: func @deinterleave_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+//       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: return %[[SHUFFLE0]]
+func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
+  %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+//       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: return %[[SHUFFLE1]]
+func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
+  %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+  return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+//       CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
+  %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+//       CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
+  %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just some nits

Copy link

github-actions bot commented Jun 12, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff bf7c505847aa58af23f14ee986ee4bb7acf22e62 d1b940b86ab6259a4151f3a2ec961c649b7a953d -- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 3cdd4ee524..58bfc7d280 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -648,10 +648,9 @@ struct VectorDeinterleaveOpConvert final
               loc, newResultType, adaptor.getSource(),
               rewriter.getI32ArrayAttr({0}));
 
-      auto elem1 =
-          rewriter.create<spirv::CompositeExtractOp>(
-              loc, newResultType, adaptor.getSource(),
-              rewriter.getI32ArrayAttr({1}));
+      auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
+          loc, newResultType, adaptor.getSource(),
+          rewriter.getI32ArrayAttr({1}));
 
       rewriter.replaceOp(deinterleaveOp,
                          {compositeExtractZero, compositeExtractOne});

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just 2 nits.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for handling this!

angelz913 and others added 12 commits June 13, 2024 20:21
1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass.
2. Added LIT tests for the new conversion.
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
@angelz913 angelz913 force-pushed the vector-deinterleave-to-spirv branch from e91f266 to 0335c7a Compare June 13, 2024 20:22
@kuhar kuhar merged commit 597cde1 into llvm:main Jun 13, 2024
5 of 6 checks passed
kuhar pushed a commit that referenced this pull request Jun 14, 2024
@angelz913 angelz913 deleted the vector-deinterleave-to-spirv branch June 14, 2024 19:31
EthanLuisMcDonough pushed a commit to EthanLuisMcDonough/llvm-project that referenced this pull request Aug 13, 2024
…vm#95313)

1. Added a conversion for `vector.deinterleave` to the `VectorToSPIRV`
pass.
2. Added LIT tests for the new conversion.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants