-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector][spirv] Lower vector.load and vector.store to SPIR-V #71674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Kai Wang (Hsiangkai) ChangesAdd patterns to lower vector.load to spirv.load and vector.store to spirv.store. Full diff: https://github.com/llvm/llvm-project/pull/71674.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..fab70e03c6a3b39 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -509,6 +509,72 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorLoadOpConverter final
+ : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = loadOp.getMemRefType();
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!attr)
+ return failure();
+
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = loadOp.getLoc();
+ Value accessChain =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+ adaptor.getIndices(), loc, rewriter);
+ if (!accessChain)
+ return failure();
+
+ spirv::StorageClass storageClass = attr.getValue();
+ auto vectorType = loadOp.getVectorType();
+ auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+ Value castedAccessChain =
+ rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+ castedAccessChain);
+
+ return success();
+ }
+};
+
+struct VectorStoreOpConverter final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = storeOp.getMemRefType();
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!attr)
+ return failure();
+
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = storeOp.getLoc();
+ Value accessChain =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+ adaptor.getIndices(), loc, rewriter);
+ if (!accessChain)
+ return failure();
+
+ spirv::StorageClass storageClass = attr.getValue();
+ auto vectorType = storeOp.getVectorType();
+ auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+ Value castedAccessChain =
+ rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
+ adaptor.getValueToStore());
+
+ return success();
+ }
+};
+
struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -614,15 +680,16 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
- VectorExtractStridedSliceOpConvert,
- VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern>(typeConverter, patterns.getContext());
+ patterns.add<
+ VectorBitcastConvert, VectorBroadcastConvert,
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+ VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ typeConverter, patterns.getContext());
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..858ecd759d0972b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -631,3 +631,105 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
%1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %1 : vector<1xf32>
}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ } {
+
+// CHECK-LABEL: @vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
+// CHECK: return %[[R0]] : vector<4xf32>
+func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_load_2d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
+// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
+// CHECK: return %[[R0]] : vector<4xf32>
+func.func @vector_load_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %0 = vector.load %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
+func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+ %idx = arith.constant 0 : index
+ vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: @vector_store_2d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
+// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>
+func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ vector.store %arg1, %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return
+}
+
+} // end module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks nice! Just a few nits. Thanks!
Add patterns to lower vector.load to spirv.load and vector.store to spirv.store.
194c8ac
to
0959b03
Compare
Add patterns to lower vector.load to spirv.load and vector.store to spirv.store.