Skip to content

Commit e3c7701

Browse files
committed
[mlir][spirv][vector] Support converting vector.from_elements to SPIR-V
1 parent ac7fe42 commit e3c7701

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,34 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
220220
}
221221
};
222222

223+
struct VectorFromElementsOpConvert final
224+
: public OpConversionPattern<vector::FromElementsOp> {
225+
using OpConversionPattern::OpConversionPattern;
226+
227+
LogicalResult
228+
matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
229+
ConversionPatternRewriter &rewriter) const override {
230+
Type resultType = getTypeConverter()->convertType(op.getType());
231+
auto elements = op.getElements();
232+
if (!resultType)
233+
return failure();
234+
if (isa<spirv::ScalarType>(resultType)) {
235+
// In the case with a single scalar operand / single-element result,
236+
// pass through the scalar.
237+
rewriter.replaceOp(op, elements[0]);
238+
return success();
239+
} else if (cast<VectorType>(resultType).getRank() == 1) {
240+
// SPIRVTypeConverter rejects vectors with rank > 1, so the
241+
// multi-dimensional vector.from_elements cases do not need to be handled,
242+
// only a simple flat vector.
243+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
244+
elements);
245+
return success();
246+
}
247+
return failure();
248+
}
249+
};
250+
223251
struct VectorInsertOpConvert final
224252
: public OpConversionPattern<vector::InsertOp> {
225253
using OpConversionPattern::OpConversionPattern;
@@ -952,8 +980,9 @@ void mlir::populateVectorToSPIRVPatterns(
952980
VectorBitcastConvert, VectorBroadcastConvert,
953981
VectorExtractElementOpConvert, VectorExtractOpConvert,
954982
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
955-
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
956-
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
983+
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
984+
VectorInsertElementOpConvert, VectorInsertOpConvert,
985+
VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
957986
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
958987
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
959988
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,25 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
217217

218218
// -----
219219

220+
// CHECK-LABEL: @from_elements_0d
221+
// CHECK-SAME: %[[ARG0:.+]]: f32
222+
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
223+
// CHECK: return %[[RETVAL]]
224+
func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
225+
%0 = vector.from_elements %arg0 : vector<f32>
226+
return %0: vector<f32>
227+
}
228+
229+
// CHECK-LABEL: @from_elements_1d
230+
// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
231+
// CHECK: spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
232+
func.func @from_elements_1d(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
233+
%0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
234+
return %0: vector<3xf32>
235+
}
236+
237+
// -----
238+
220239
// CHECK-LABEL: @insert
221240
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
222241
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>

0 commit comments

Comments
 (0)