Skip to content

Commit 599a91a

Browse files
angelz913kuhar
andauthored
[mlir][spirv] Support vector.step in vector to spirv conversion (#100651)
Added a conversion pattern and LIT tests for lowering `vector.step` to SPIR-V. Fixes: #100602 --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent ca69f51 commit 599a91a

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,43 @@ struct VectorReductionToFPDotProd final
906906
}
907907
};
908908

909+
struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
910+
using OpConversionPattern::OpConversionPattern;
911+
912+
LogicalResult
913+
matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
914+
ConversionPatternRewriter &rewriter) const override {
915+
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
916+
Type dstType = typeConverter.convertType(stepOp.getType());
917+
if (!dstType)
918+
return failure();
919+
920+
Location loc = stepOp.getLoc();
921+
int64_t numElements = stepOp.getType().getNumElements();
922+
auto intType =
923+
rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
924+
925+
// Input vectors of size 1 are converted to scalars by the type converter.
926+
// We just create a constant in this case.
927+
if (numElements == 1) {
928+
Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
929+
rewriter.replaceOp(stepOp, zero);
930+
return success();
931+
}
932+
933+
SmallVector<Value> source;
934+
source.reserve(numElements);
935+
for (int64_t i = 0; i < numElements; ++i) {
936+
Attribute intAttr = rewriter.getIntegerAttr(intType, i);
937+
Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
938+
source.push_back(constOp);
939+
}
940+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
941+
source);
942+
return success();
943+
}
944+
};
945+
909946
} // namespace
910947
#define CL_INT_MAX_MIN_OPS \
911948
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +966,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
929966
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
930967
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
931968
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
932-
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
933-
typeConverter, patterns.getContext(), PatternBenefit(1));
969+
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
970+
VectorStepOpConvert>(typeConverter, patterns.getContext(),
971+
PatternBenefit(1));
934972

935973
// Make sure that the more specialized dot product pattern has higher benefit
936974
// than the generic one that extracts all elements.

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,32 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
794794

795795
// -----
796796

797+
// CHECK-LABEL: @step()
798+
// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
799+
// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32
800+
// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32
801+
// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32
802+
// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
803+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
804+
// CHECK: return %[[CAST]] : vector<4xindex>
805+
func.func @step() -> vector<4xindex> {
806+
%0 = vector.step : vector<4xindex>
807+
return %0 : vector<4xindex>
808+
}
809+
810+
// -----
811+
812+
// CHECK-LABEL: @step_size1()
813+
// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
814+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
815+
// CHECK: return %[[CAST]] : vector<1xindex>
816+
func.func @step_size1() -> vector<1xindex> {
817+
%0 = vector.step : vector<1xindex>
818+
return %0 : vector<1xindex>
819+
}
820+
821+
// -----
822+
797823
module attributes {
798824
spirv.target_env = #spirv.target_env<
799825
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>

0 commit comments

Comments
 (0)