@@ -220,6 +220,34 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
220
220
}
221
221
};
222
222
223
+ struct VectorFromElementsOpConvert final
224
+ : public OpConversionPattern<vector::FromElementsOp> {
225
+ using OpConversionPattern::OpConversionPattern;
226
+
227
+ LogicalResult
228
+ matchAndRewrite (vector::FromElementsOp op, OpAdaptor adaptor,
229
+ ConversionPatternRewriter &rewriter) const override {
230
+ Type resultType = getTypeConverter ()->convertType (op.getType ());
231
+ auto elements = op.getElements ();
232
+ if (!resultType)
233
+ return failure ();
234
+ if (isa<spirv::ScalarType>(resultType)) {
235
+ // In the case with a single scalar operand / single-element result,
236
+ // pass through the scalar.
237
+ rewriter.replaceOp (op, elements[0 ]);
238
+ return success ();
239
+ } else if (cast<VectorType>(resultType).getRank () == 1 ) {
240
+ // SPIRVTypeConverter rejects vectors with rank > 1, so the
241
+ // multi-dimensional vector.from_elements cases do not need to be handled,
242
+ // only a simple flat vector.
243
+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(op, resultType,
244
+ elements);
245
+ return success ();
246
+ }
247
+ return failure ();
248
+ }
249
+ };
250
+
223
251
struct VectorInsertOpConvert final
224
252
: public OpConversionPattern<vector::InsertOp> {
225
253
using OpConversionPattern::OpConversionPattern;
@@ -952,8 +980,9 @@ void mlir::populateVectorToSPIRVPatterns(
952
980
VectorBitcastConvert, VectorBroadcastConvert,
953
981
VectorExtractElementOpConvert, VectorExtractOpConvert,
954
982
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
955
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
956
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
983
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
984
+ VectorInsertElementOpConvert, VectorInsertOpConvert,
985
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
957
986
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
958
987
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
959
988
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
0 commit comments