Skip to content

Commit 0a2116f

Browse files
authored
[mlir][spirv][vector] Support converting vector.from_elements to SPIR-V (#118540)
Closes #118098.
1 parent 85d15bd commit 0a2116f

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,32 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
220220
}
221221
};
222222

223+
struct VectorFromElementsOpConvert final
224+
: public OpConversionPattern<vector::FromElementsOp> {
225+
using OpConversionPattern::OpConversionPattern;
226+
227+
LogicalResult
228+
matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
229+
ConversionPatternRewriter &rewriter) const override {
230+
Type resultType = getTypeConverter()->convertType(op.getType());
231+
if (!resultType)
232+
return failure();
233+
OperandRange elements = op.getElements();
234+
if (isa<spirv::ScalarType>(resultType)) {
235+
// In the case with a single scalar operand / single-element result,
236+
// pass through the scalar.
237+
rewriter.replaceOp(op, elements[0]);
238+
return success();
239+
}
240+
// SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
241+
// vector.from_elements cases should not need to be handled, only 1d.
242+
assert(cast<VectorType>(resultType).getRank() == 1);
243+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
244+
elements);
245+
return success();
246+
}
247+
};
248+
223249
struct VectorInsertOpConvert final
224250
: public OpConversionPattern<vector::InsertOp> {
225251
using OpConversionPattern::OpConversionPattern;
@@ -952,8 +978,9 @@ void mlir::populateVectorToSPIRVPatterns(
952978
VectorBitcastConvert, VectorBroadcastConvert,
953979
VectorExtractElementOpConvert, VectorExtractOpConvert,
954980
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
955-
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
956-
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
981+
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
982+
VectorInsertElementOpConvert, VectorInsertOpConvert,
983+
VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
957984
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
958985
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
959986
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,35 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
217217

218218
// -----
219219

220+
// CHECK-LABEL: @from_elements_0d
221+
// CHECK-SAME: %[[ARG0:.+]]: f32
222+
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
223+
// CHECK: return %[[RETVAL]]
224+
func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
225+
%0 = vector.from_elements %arg0 : vector<f32>
226+
return %0: vector<f32>
227+
}
228+
229+
// CHECK-LABEL: @from_elements_1x
230+
// CHECK-SAME: %[[ARG0:.+]]: f32
231+
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
232+
// CHECK: return %[[RETVAL]]
233+
func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
234+
%0 = vector.from_elements %arg0 : vector<1xf32>
235+
return %0: vector<1xf32>
236+
}
237+
238+
// CHECK-LABEL: @from_elements_3x
239+
// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
240+
// CHECK: %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
241+
// CHECK: return %[[RETVAL]]
242+
func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
243+
%0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
244+
return %0: vector<3xf32>
245+
}
246+
247+
// -----
248+
220249
// CHECK-LABEL: @insert
221250
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
222251
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>

0 commit comments

Comments
 (0)