-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Support vector.step
in vector to spirv conversion
#100651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Angel Zhang (angelz913) ChangesAdded a conversion pattern and LIT tests for lowering Full diff: https://github.com/llvm/llvm-project/pull/100651.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
}
};
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Type dstType = typeConverter.convertType(stepOp.getType());
+ if (!dstType)
+ return failure();
+
+ Location loc = stepOp.getLoc();
+ int64_t numElements = stepOp.getType().getNumElements();
+ auto intType =
+ rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+ // Input vectors of size 1 are converted to scalars by the type converter.
+ // We just create a constant in this case.
+ if (numElements == 1) {
+ Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+ rewriter.replaceOp(stepOp, zero);
+ return success();
+ }
+
+ SmallVector<Value> source;
+ for (int64_t i = 0; i < numElements; ++i) {
+ Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+ Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+ source.push_back(constOp);
+ }
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+ source);
+ return success();
+ }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+ VectorStepOpConvert>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
// -----
+// CHECK-LABEL: @step
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32
+// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32
+// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+// CHECK: return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+ %0 = vector.step : vector<4xindex>
+ return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+// CHECK: return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+ %0 = vector.step : vector<1xindex>
+ return %0 : vector<1xindex>
+}
+
+// -----
+
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
|
@llvm/pr-subscribers-mlir-spirv Author: Angel Zhang (angelz913) ChangesAdded a conversion pattern and LIT tests for lowering Full diff: https://github.com/llvm/llvm-project/pull/100651.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
}
};
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Type dstType = typeConverter.convertType(stepOp.getType());
+ if (!dstType)
+ return failure();
+
+ Location loc = stepOp.getLoc();
+ int64_t numElements = stepOp.getType().getNumElements();
+ auto intType =
+ rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+ // Input vectors of size 1 are converted to scalars by the type converter.
+ // We just create a constant in this case.
+ if (numElements == 1) {
+ Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+ rewriter.replaceOp(stepOp, zero);
+ return success();
+ }
+
+ SmallVector<Value> source;
+ for (int64_t i = 0; i < numElements; ++i) {
+ Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+ Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+ source.push_back(constOp);
+ }
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+ source);
+ return success();
+ }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+ VectorStepOpConvert>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
// -----
+// CHECK-LABEL: @step
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32
+// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32
+// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+// CHECK: return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+ %0 = vector.step : vector<4xindex>
+ return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+// CHECK: return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+ %0 = vector.step : vector<1xindex>
+ return %0 : vector<1xindex>
+}
+
+// -----
+
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nits
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Added a conversion pattern and LIT tests for lowering
vector.step
to SPIR-V.Fixes: #100602