-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Implement SPIR-V lowering for vector.deinterleave
#95313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Angel Zhang (angelz913) Changes
Full diff: https://github.com/llvm/llvm-project/pull/95313.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 92168cfa36147..b9a086cfc91a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
}
};
+struct VectorDeinterleaveOpConvert final
+ : public OpConversionPattern<vector::DeinterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Check the result vector type.
+ VectorType oldResultType = deinterleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(deinterleaveOp,
+ "unsupported result vector type");
+
+ // Get location.
+ Location loc = deinterleaveOp->getLoc();
+
+ // Deinterleave the indices.
+ VectorType sourceType = deinterleaveOp.getSourceVectorType();
+ int n = sourceType.getNumElements();
+
+ // Output vectors of size 1 are converted to scalars by the type converter.
+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+ // use `spirv::CompositeExtractOp`.
+ if (n == 2) {
+ spirv::CompositeExtractOp compositeExtractZero =
+ rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({0}));
+
+ spirv::CompositeExtractOp compositeExtractOne =
+ rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({1}));
+
+ rewriter.replaceOp(deinterleaveOp,
+ {compositeExtractZero, compositeExtractOne});
+ return success();
+ }
+
+ // Indices for `res1`.
+ auto seqEven = llvm::seq<int64_t>(n / 2);
+ auto indicesEven =
+ llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
+
+ // Indices for `res2`.
+ auto seqOdd = llvm::seq<int64_t>(n / 2);
+ auto indicesOdd =
+ llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
+
+ // Create two SPIR-V shuffles.
+ spirv::VectorShuffleOp shuffleEven =
+ rewriter.create<spirv::VectorShuffleOp>(
+ loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+ rewriter.getI32ArrayAttr(indicesEven));
+
+ spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+ loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+ rewriter.getI32ArrayAttr(indicesOdd));
+
+ // Replace deinterleaveOp with SPIR-V shuffles.
+ rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
+
+ return success();
+ }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -862,9 +930,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
- VectorStoreOpConverter>(typeConverter, patterns.getContext(),
- PatternBenefit(1));
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 2592d0fc04111..87823ab9afc0f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,6 +507,56 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
// -----
+// CHECK-LABEL: func @deinterleave_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: return %[[SHUFFLE0]]
+func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
+ %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: return %[[SHUFFLE1]]
+func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
+ %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+ return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+// CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
+ %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+ return %0 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+// CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
+ %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just some nits
✅ With the latest revision this PR passed the C/C++ code formatter. |
You can test this locally with the following command:git-clang-format --diff bf7c505847aa58af23f14ee986ee4bb7acf22e62 d1b940b86ab6259a4151f3a2ec961c649b7a953d -- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp View the diff from clang-format here.diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 3cdd4ee524..58bfc7d280 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -648,10 +648,9 @@ struct VectorDeinterleaveOpConvert final
loc, newResultType, adaptor.getSource(),
rewriter.getI32ArrayAttr({0}));
- auto elem1 =
- rewriter.create<spirv::CompositeExtractOp>(
- loc, newResultType, adaptor.getSource(),
- rewriter.getI32ArrayAttr({1}));
+ auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({1}));
rewriter.replaceOp(deinterleaveOp,
{compositeExtractZero, compositeExtractOne});
|
997bb08
to
ddf3125
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just 2 nits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for handling this!
1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass. 2. Added LIT tests for the new conversion.
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
e91f266
to
0335c7a
Compare
This commit is dependent on #95313.
…vm#95313) 1. Added a conversion for `vector.deinterleave` to the `VectorToSPIRV` pass. 2. Added LIT tests for the new conversion. --------- Co-authored-by: Jakub Kuderski <[email protected]>
vector.deinterleave
to theVectorToSPIRV
pass.