[mlir][vector][spirv] Lower vector.transfer_read and vector.transfer_write to SPIR-V #69708


Closed

Conversation

Hsiangkai
Contributor

Add patterns to lower vector.transfer_read to spirv.load and vector.transfer_write to spirv.store.

@llvmbot
Member

llvmbot commented Oct 20, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Kai Wang (Hsiangkai)

Changes

Add patterns to lower vector.transfer_read to spirv.load and vector.transfer_write to spirv.store.


Full diff: https://github.com/llvm/llvm-project/pull/69708.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+84-1)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+50)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..9b320e49996e0a3 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -509,6 +509,87 @@ struct VectorShuffleOpConvert final
   }
 };
 
+struct VectorTransferReadOpConverter final
+    : public OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp transferReadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (transferReadOp.getMask())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "unsupported transfer_read with mask");
+    auto sourceType = transferReadOp.getSource().getType();
+    if (!llvm::isa<MemRefType>(sourceType))
+      return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
+
+    auto memrefType = cast<MemRefType>(sourceType);
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = transferReadOp.getLoc();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return failure();
+
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return failure();
+
+    spirv::StorageClass storageClass = attr.getValue();
+    auto vectorType = transferReadOp.getVectorType();
+    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+    Value castedAccessChain =
+        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(transferReadOp, vectorType,
+                                               castedAccessChain);
+
+    return success();
+  }
+};
+
+struct VectorTransferWriteOpConverter final
+    : public OpConversionPattern<vector::TransferWriteOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp transferWriteOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (transferWriteOp.getMask())
+      return rewriter.notifyMatchFailure(
+          transferWriteOp, "unsupported transfer_write with mask");
+    auto sourceType = transferWriteOp.getSource().getType();
+    if (!llvm::isa<MemRefType>(sourceType))
+      return rewriter.notifyMatchFailure(transferWriteOp,
+                                         "not a memref source");
+
+    auto memrefType = cast<MemRefType>(sourceType);
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = transferWriteOp.getLoc();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return failure();
+
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return failure();
+
+    spirv::StorageClass storageClass = attr.getValue();
+    auto vectorType = transferWriteOp.getVectorType();
+    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+    Value castedAccessChain =
+        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        transferWriteOp, castedAccessChain, adaptor.getVector());
+
+    return success();
+  }
+};
+
 struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -622,7 +703,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
                VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
                VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+               VectorSplatPattern, VectorTransferReadOpConverter,
+               VectorTransferWriteOpConverter>(typeConverter,
+                                               patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..c6e7b6519ad8517 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -631,3 +631,53 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
   %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
   return %1 : vector<1xf32>
 }
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  } {
+
+// CHECK-LABEL: @transfer_read
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]]
+//       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]]
+//       CHECK:   %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
+//       CHECK:   return %[[R0]]
+func.func @transfer_read(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%idx], %cst_0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @transfer_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]]
+//       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]]
+//       CHECK:   spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
+func.func @transfer_write(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%idx] : vector<4xf32>, memref<4xf32, #spirv.storage_class<StorageBuffer>>
+  return
+}
+
+} // end module

Hsiangkai requested a review from kuhar on October 24, 2023, 08:30
@antiagainst
Member

Thanks for the patterns. Actually, I'm wondering whether we need them--vector transfer ops are high-level abstractions; they provide powerful mechanisms with padding, in-bounds handling, and permutation maps. We'd typically want to first lower these transfer ops into vector.load/store or memref.load/store and then lower those to SPIR-V. Have you considered those paths?
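
For reference, a minimal sketch of the staged lowering being suggested here, using hypothetical IR: an unmasked, fully in-bounds transfer is first rewritten to a plain vector.load by the vector-to-vector lowering patterns, and only that simpler op would then need a SPIR-V conversion.

```mlir
// An unmasked, fully in-bounds transfer_read...
%v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
    : memref<8xf32>, vector<4xf32>
// ...is equivalent to a plain vector.load, which carries no padding,
// mask, or permutation map:
%v = vector.load %mem[%i] : memref<8xf32>, vector<4xf32>
```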

const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
auto loc = transferReadOp.getLoc();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
Member

This only works when the permutation_map is an identity map? I don't see checks to exclude other cases.
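For instance, a minimal hypothetical case the pattern would currently accept but lower incorrectly:

```mlir
// A non-identity permutation map: this reads a column (stride-4
// elements) of a 2-D memref, which a single contiguous spirv.Load
// cannot express.
%v = vector.transfer_read %m[%i, %j], %pad
    {permutation_map = affine_map<(d0, d1) -> (d0)>}
    : memref<4x4xf32>, vector<4xf32>
```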

Contributor Author

Fixed.

Contributor Author

I reverted the fix. I am still thinking about why we need an identity map here. Could you give more information about it?

@Hsiangkai
Contributor Author

> Thanks for the patterns. Actually, I'm wondering whether we need them--vector transfer ops are high-level abstractions; they provide powerful mechanisms with padding, in-bounds handling, and permutation maps. We'd typically want to first lower these transfer ops into vector.load/store or memref.load/store and then lower those to SPIR-V. Have you considered those paths?

The documentation says: "A vector.transfer_read can be lowered to a simple load if all dimensions are specified to be within bounds and no mask was specified." Under these conditions, does it make sense to convert transfer_read/transfer_write to SPIR-V directly? I have not tried converting transfer_read to memref.load, though.

@kuhar
Member

kuhar commented Oct 31, 2023

> > Thanks for the patterns. Actually, I'm wondering whether we need them--vector transfer ops are high-level abstractions; they provide powerful mechanisms with padding, in-bounds handling, and permutation maps. We'd typically want to first lower these transfer ops into vector.load/store or memref.load/store and then lower those to SPIR-V. Have you considered those paths?
>
> The documentation says: "A vector.transfer_read can be lowered to a simple load if all dimensions are specified to be within bounds and no mask was specified." Under these conditions, does it make sense to convert transfer_read/transfer_write to SPIR-V directly? I have not tried converting transfer_read to memref.load, though.

What's the benefit of skipping memref loads/stores? It seems like this will complicate the lowering and prevent us from running target-independent transforms at the level of memref ops, e.g., narrow-type emulation.
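
A rough, hypothetical sketch of what narrow-type emulation does at this level; the types and exact rewrite shape are illustrative only:

```mlir
// Before emulation: a load of sub-byte (i4) elements.
%v = vector.load %m[%c0] : memref<8xi4>, vector<8xi4>
// After emulation, roughly: the same 32 bits are read through a
// byte-addressable i8 memref and reinterpreted, so no i4 memory op
// survives. Lowering transfer ops straight to SPIR-V would skip this.
%w = vector.load %m2[%c0] : memref<4xi8>, vector<4xi8>
%v2 = vector.bitcast %w : vector<4xi8> to vector<8xi4>
```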

[mlir][vector][spirv] Lower vector.transfer_read and vector.transfer_write to SPIR-V

Add patterns to lower vector.transfer_read to spirv.load and
vector.transfer_write to spirv.store.
@antiagainst
Member

> The documentation says: "A vector.transfer_read can be lowered to a simple load if all dimensions are specified to be within bounds and no mask was specified." Under these conditions, does it make sense to convert transfer_read/transfer_write to SPIR-V directly? I have not tried converting transfer_read to memref.load, though.

Yup. That "simple load" is modelled as vector.load or memref.load at a lower level. Keep in mind that the core concept behind MLIR is to have suitable representations at the proper level, so we can have a principled lowering flow and share common utilities--we have already built quite a few pieces there that you might find valuable in your flow. For example, the lowering from vector transfer ops to lower-level ops already exists in LowerVectorTransfer.cpp, and we have things like narrow-type emulation for those lower-level load ops, as @kuhar mentioned.

Right now we are missing the vector.load/store to spirv lowering. For the cases you want to cover, I think it's worth adding that instead, which is much simpler (as we don't need to verify all the out-of-bounds, padding, indexing-map, etc. cases).

There might be a place where we want to lower vector transfer ops to spirv directly--for genuinely powerful cases not representable by vector.load/store or memref.load/store, where we have direct hardware support exposed by spirv ops. That would be a sign that we may want to discuss the representation at the vector/memref load/store level, and that can take time; in the meanwhile, it would make sense to connect things directly to keep work going.

Hopefully this explains the rationale. Again, thanks for the contribution! :)
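
A minimal sketch of the suggested vector.load path, with hypothetical IR; the SPIR-V side is shown in comments since it would mirror what this patch already emits:

```mlir
// The missing pattern would take a plain vector.load...
%v = vector.load %buf[%i]
    : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
// ...and produce the same sequence as this patch, minus the transfer-op
// legality concerns:
//   %p  = spirv.AccessChain ...   (element pointer from the indices)
//   %pv = spirv.Bitcast %p        (scalar pointer -> vector pointer)
//   %v  = spirv.Load "StorageBuffer" %pv : vector<4xf32>
```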

@Hsiangkai
Contributor Author

> > The documentation says: "A vector.transfer_read can be lowered to a simple load if all dimensions are specified to be within bounds and no mask was specified." Under these conditions, does it make sense to convert transfer_read/transfer_write to SPIR-V directly? I have not tried converting transfer_read to memref.load, though.
>
> Yup. That "simple load" is modelled as vector.load or memref.load at a lower level. Keep in mind that the core concept behind MLIR is to have suitable representations at the proper level, so we can have a principled lowering flow and share common utilities--we have already built quite a few pieces there that you might find valuable in your flow. For example, the lowering from vector transfer ops to lower-level ops already exists in LowerVectorTransfer.cpp, and we have things like narrow-type emulation for those lower-level load ops, as @kuhar mentioned.
>
> Right now we are missing the vector.load/store to spirv lowering. For the cases you want to cover, I think it's worth adding that instead, which is much simpler (as we don't need to verify all the out-of-bounds, padding, indexing-map, etc. cases).
>
> There might be a place where we want to lower vector transfer ops to spirv directly--for genuinely powerful cases not representable by vector.load/store or memref.load/store, where we have direct hardware support exposed by spirv ops. That would be a sign that we may want to discuss the representation at the vector/memref load/store level, and that can take time; in the meanwhile, it would make sense to connect things directly to keep work going.
>
> Hopefully this explains the rationale. Again, thanks for the contribution! :)

Hi @kuhar and @antiagainst,

Sorry for the late reply, and thanks for the information. I found that there are patterns to lower transfer_read to memref::Load in the convert-vector-to-gpu pass. I tried adding them to my pipeline, but it doesn't work. I also found some patterns that convert vector::TransferReadOp to nvgpu::LdMatrixOp, which is a low-level dialect; that's why I thought it might make sense to convert transfer_read to SPIR-V directly. Thanks for providing the rationale. I will look into LowerVectorTransfer.cpp and VectorEmulateNarrowType.cpp and figure out how to convert transfer_read to memref::Load first in my pipeline. Thank you very much! :)

@antiagainst
Member

Depending on your use case, converting to memref.load may not work out directly. If you want to read a vector out of a scalar memref, you may need to connect the flow from vector transfer ops to vector.load and then add the missing pattern from vector.load to spirv.load.

@Hsiangkai
Contributor Author

> Depending on your use case, converting to memref.load may not work out directly. If you want to read a vector out of a scalar memref, you may need to connect the flow from vector transfer ops to vector.load and then add the missing pattern from vector.load to spirv.load.

Thanks a lot. I can lower to vector.load/vector.maskedload now. I will add patterns from vector.maskedload and vector.load to spirv.load later. I will close this PR for now.
