[mlir][vector][spirv] Lower vector.transfer_read and vector.transfer_write to SPIR-V #69708
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir

Author: Kai Wang (Hsiangkai)

Changes: Add patterns to lower vector.transfer_read to spirv.load and vector.transfer_write to spirv.store.

Full diff: https://github.com/llvm/llvm-project/pull/69708.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..9b320e49996e0a3 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -509,6 +509,87 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorTransferReadOpConverter final
+ : public OpConversionPattern<vector::TransferReadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferReadOp transferReadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (transferReadOp.getMask())
+ return rewriter.notifyMatchFailure(transferReadOp,
+ "unsupported transfer_read with mask");
+ auto sourceType = transferReadOp.getSource().getType();
+ if (!llvm::isa<MemRefType>(sourceType))
+ return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
+
+ auto memrefType = cast<MemRefType>(sourceType);
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = transferReadOp.getLoc();
+ Value accessChain =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
+ adaptor.getIndices(), loc, rewriter);
+ if (!accessChain)
+ return failure();
+
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!attr)
+ return failure();
+
+ spirv::StorageClass storageClass = attr.getValue();
+ auto vectorType = transferReadOp.getVectorType();
+ auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+ Value castedAccessChain =
+ rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(transferReadOp, vectorType,
+ castedAccessChain);
+
+ return success();
+ }
+};
+
+struct VectorTransferWriteOpConverter final
+ : public OpConversionPattern<vector::TransferWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferWriteOp transferWriteOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (transferWriteOp.getMask())
+ return rewriter.notifyMatchFailure(
+ transferWriteOp, "unsupported transfer_write with mask");
+ auto sourceType = transferWriteOp.getSource().getType();
+ if (!llvm::isa<MemRefType>(sourceType))
+ return rewriter.notifyMatchFailure(transferWriteOp,
+ "not a memref source");
+
+ auto memrefType = cast<MemRefType>(sourceType);
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = transferWriteOp.getLoc();
+ Value accessChain =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
+ adaptor.getIndices(), loc, rewriter);
+ if (!accessChain)
+ return failure();
+
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!attr)
+ return failure();
+
+ spirv::StorageClass storageClass = attr.getValue();
+ auto vectorType = transferWriteOp.getVectorType();
+ auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+ Value castedAccessChain =
+ rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ transferWriteOp, castedAccessChain, adaptor.getVector());
+
+ return success();
+ }
+};
+
struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -622,7 +703,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern>(typeConverter, patterns.getContext());
+ VectorSplatPattern, VectorTransferReadOpConverter,
+ VectorTransferWriteOpConverter>(typeConverter,
+ patterns.getContext());
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..c6e7b6519ad8517 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -631,3 +631,53 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
%1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %1 : vector<1xf32>
}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ } {
+
+// CHECK-LABEL: @transfer_read
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]]
+// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]]
+// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
+// CHECK: return %[[R0]]
+func.func @transfer_read(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %0 = vector.transfer_read %arg0[%idx], %cst_0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @transfer_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]]
+// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]]
+// CHECK: spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
+func.func @transfer_write(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+ %idx = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%idx] : vector<4xf32>, memref<4xf32, #spirv.storage_class<StorageBuffer>>
+ return
+}
+
+} // end module
Thanks for the patterns. Actually I'm wondering whether we need them: vector transfer ops are high-level abstractions; they provide powerful mechanisms with padding, in-bounds attributes, and permutation maps. We'd typically want to first lower these transfer ops into vector.load/store or memref.load/store and then lower those to SPIR-V. Have you considered those paths?
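To make the suggestion concrete, here is a rough, untested sketch of what that two-step pipeline could look like. Helper names outside this PR's diff (e.g. `populateVectorTransferLoweringPatterns`) are quoted from memory and should be double-checked against the actual headers:

```cpp
// Sketch only: lower transfer ops to vector.load/store first, then convert
// the result to SPIR-V. Functions outside this PR's diff are assumptions
// based on existing MLIR helpers and may need adjusting.
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult lowerTransfersThenConvertToSPIRV(
    Operation *op, SPIRVTypeConverter &typeConverter,
    ConversionTarget &target) {
  MLIRContext *ctx = op->getContext();

  // Step 1: rewrite unmasked, in-bounds transfer_read/write into
  // vector.load/store (restricting to rank-1 transfers here).
  {
    RewritePatternSet patterns(ctx);
    vector::populateVectorTransferLoweringPatterns(patterns,
                                                   /*maxTransferRank=*/1);
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      return failure();
  }

  // Step 2: convert the resulting vector ops to SPIR-V.
  RewritePatternSet patterns(ctx);
  populateVectorToSPIRVPatterns(typeConverter, patterns);
  return applyPartialConversion(op, target, std::move(patterns));
}
```

With something like that in place, the VectorToSPIRV patterns only ever see vector.load/vector.store, which keeps the SPIR-V conversion itself simple.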
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    auto loc = transferReadOp.getLoc();
    Value accessChain =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
This only works when the permutation_map is an identity map? I don't see checks to exclude other cases.
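If the direct pattern is kept, I would expect guards roughly like the following (untested sketch; it assumes the usual vector::TransferReadOp accessors `hasOutOfBoundsDim()` and `getPermutationMap()`):

```cpp
// Possible guards restricting the pattern to the "simple load" case:
// unmasked, fully in-bounds, and a (minor) identity permutation map, so the
// access really is a contiguous load of the trailing dimension.
if (transferReadOp.getMask())
  return rewriter.notifyMatchFailure(transferReadOp, "masked transfer");
if (transferReadOp.hasOutOfBoundsDim())
  return rewriter.notifyMatchFailure(transferReadOp,
                                     "out-of-bounds dimensions");
if (!transferReadOp.getPermutationMap().isMinorIdentity())
  return rewriter.notifyMatchFailure(transferReadOp,
                                     "non-identity permutation map");
```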
Fixed.
I reverted the fix. I am still thinking about why we need an identity map here. Can you give more information about it?
There is a description in the documentation: "A vector.transfer_read can be lowered to a simple load if all dimensions are specified to be within bounds and no mask was specified". Under these conditions, does it make sense to convert transfer_read/transfer_write to SPIR-V directly? I have not tried converting transfer_read to memref.load, though.
What's the benefit of skipping memref loads/stores? It seems like this will complicate the lowering and prevent us from running target-independent transforms at the level of memref ops, e.g., narrow type emulation.
Yup. That "simple load" is modelled as Right now we are missing vector.load/store to spirv lowering. For the cases you want to cover, I think it's worth adding them instead, which is much simpler (as we don't need to verify all those out of bounds, padding, indexing map, etc.). There might be a place where we want to introduce vector transfer ops to spirv directly--for actually powerful cases not representable by vector.load/store and memref.load/store, but we have direct hardware support that sort of exposed by spirv ops. That is a sign of maybe we want to discuss the representation at vector/memref load/store level; and that can take time. In the meanwhile it makes sense to connect directly there to keep things going. Hopefully this explains the rationale. Again, thanks for the contribution! :) |
Hi @kuhar and @antiagainst, sorry for the late reply and thanks for the information. I found there are patterns to lower transfer_read to memref::Load in the convert-vector-to-gpu pass. I tried to add them to my pipeline, but it didn't work. I also found there are some patterns to convert vector::TransferReadOp to nvgpu::LdMatrixOp, which is a low-level dialect. That's why I thought it might make sense to convert transfer_read to SPIR-V directly. Thanks for explaining the rationale to me. I will look into LowerVectorTransfer.cpp and VectorEmulateNarrowType.cpp and figure out how to convert transfer_read to memref::Load first in my pipeline. Thank you very much! :)
Depending on your use case, converting to memref.load may not work out directly. If you want to read a vector out of a scalar memref, you may need to connect the flow from vector transfer ops to vector.load and then add the missing pattern from vector.load to spirv.load.
Thanks a lot. I can lower to vector.load/maskedload now. I will add patterns from vector.maskedload and vector.load to spirv.load later. I will close this PR for now.