-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv][vector] Support converting vector.from_elements to SPIR-V #118540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv][vector] Support converting vector.from_elements to SPIR-V #118540
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-spirv Author: Andrea Faulds (andfau-amd) ChangesFull diff: https://github.com/llvm/llvm-project/pull/118540.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 656b1cb3e99a1d..a2dbbab34c1db7 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -220,6 +220,34 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
}
};
+struct VectorFromElementsOpConvert final
+ : public OpConversionPattern<vector::FromElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type resultType = getTypeConverter()->convertType(op.getType());
+ auto elements = op.getElements();
+ if (!resultType)
+ return failure();
+ if (isa<spirv::ScalarType>(resultType)) {
+ // In the case with a single scalar operand / single-element result,
+ // pass through the scalar.
+ rewriter.replaceOp(op, elements[0]);
+ return success();
+ } else if (cast<VectorType>(resultType).getRank() == 1) {
+ // SPIRVTypeConverter rejects vectors with rank > 1, so the
+ // multi-dimensional vector.from_elements cases do not need to be handled,
+ // only a simple flat vector.
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
+ elements);
+ return success();
+ }
+ return failure();
+ }
+};
+
struct VectorInsertOpConvert final
: public OpConversionPattern<vector::InsertOp> {
using OpConversionPattern::OpConversionPattern;
@@ -952,8 +980,9 @@ void mlir::populateVectorToSPIRVPatterns(
VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+ VectorInsertElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8796f153c4911b..f9dbe527af2c56 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -217,6 +217,25 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
// -----
+// CHECK-LABEL: @from_elements_0d
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: return %[[RETVAL]]
+func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
+ %0 = vector.from_elements %arg0 : vector<f32>
+ return %0: vector<f32>
+}
+
+// CHECK-LABEL: @from_elements_1d
+// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
+// CHECK: spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
+func.func @from_elements_1d(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
+ return %0: vector<3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
|
e3c7701
to
567e741
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % one check
567e741
to
f1454e7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Huh, why can't I "Request changes" from myself? GitHub is awkward :/)
✅ With the latest revision this PR passed the C/C++ code formatter. |
f1454e7
to
0c26bd7
Compare
Closes #118098.