@@ -906,6 +906,43 @@ struct VectorReductionToFPDotProd final
906
906
}
907
907
};
908
908
909
+ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
910
+ using OpConversionPattern::OpConversionPattern;
911
+
912
+ LogicalResult
913
+ matchAndRewrite (vector::StepOp stepOp, OpAdaptor adaptor,
914
+ ConversionPatternRewriter &rewriter) const override {
915
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
916
+ Type dstType = typeConverter.convertType (stepOp.getType ());
917
+ if (!dstType)
918
+ return failure ();
919
+
920
+ Location loc = stepOp.getLoc ();
921
+ int64_t numElements = stepOp.getType ().getNumElements ();
922
+ auto intType =
923
+ rewriter.getIntegerType (typeConverter.getIndexTypeBitwidth ());
924
+
925
+ // Input vectors of size 1 are converted to scalars by the type converter.
926
+ // We just create a constant in this case.
927
+ if (numElements == 1 ) {
928
+ Value zero = spirv::ConstantOp::getZero (intType, loc, rewriter);
929
+ rewriter.replaceOp (stepOp, zero);
930
+ return success ();
931
+ }
932
+
933
+ SmallVector<Value> source;
934
+ source.reserve (numElements);
935
+ for (int64_t i = 0 ; i < numElements; ++i) {
936
+ Attribute intAttr = rewriter.getIntegerAttr (intType, i);
937
+ Value constOp = rewriter.create <spirv::ConstantOp>(loc, intType, intAttr);
938
+ source.push_back (constOp);
939
+ }
940
+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(stepOp, dstType,
941
+ source);
942
+ return success ();
943
+ }
944
+ };
945
+
909
946
} // namespace
910
947
#define CL_INT_MAX_MIN_OPS \
911
948
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +966,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
929
966
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
930
967
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
931
968
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
932
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
933
- typeConverter, patterns.getContext (), PatternBenefit (1 ));
969
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
970
+ VectorStepOpConvert>(typeConverter, patterns.getContext (),
971
+ PatternBenefit (1 ));
934
972
935
973
// Make sure that the more specialized dot product pattern has higher benefit
936
974
// than the generic one that extracts all elements.
0 commit comments