Skip to content

[mlir][vector][spirv] Lower vector.load and vector.store to SPIR-V #71674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 24, 2023

Conversation

Hsiangkai
Copy link
Contributor

Add patterns to lower vector.load to spirv.load and vector.store to spirv.store.

@llvmbot
Copy link
Member

llvmbot commented Nov 8, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Kai Wang (Hsiangkai)

Changes

Add patterns to lower vector.load to spirv.load and vector.store to spirv.store.


Full diff: https://github.com/llvm/llvm-project/pull/71674.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+76-9)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+102)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..fab70e03c6a3b39 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -509,6 +509,72 @@ struct VectorShuffleOpConvert final
   }
 };
 
+struct VectorLoadOpConverter final
+    : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = loadOp.getMemRefType();
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return failure();
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = loadOp.getLoc();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return failure();
+
+    spirv::StorageClass storageClass = attr.getValue();
+    auto vectorType = loadOp.getVectorType();
+    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+    Value castedAccessChain =
+        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+                                               castedAccessChain);
+
+    return success();
+  }
+};
+
+struct VectorStoreOpConverter final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = storeOp.getMemRefType();
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return failure();
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = storeOp.getLoc();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return failure();
+
+    spirv::StorageClass storageClass = attr.getValue();
+    auto vectorType = storeOp.getVectorType();
+    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+    Value castedAccessChain =
+        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
+                                                adaptor.getValueToStore());
+
+    return success();
+  }
+};
+
 struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -614,15 +680,16 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
-               VectorExtractElementOpConvert, VectorExtractOpConvert,
-               VectorExtractStridedSliceOpConvert,
-               VectorFmaOpConvert<spirv::GLFmaOp>,
-               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-               VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-               VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+  patterns.add<
+      VectorBitcastConvert, VectorBroadcastConvert,
+      VectorExtractElementOpConvert, VectorExtractOpConvert,
+      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      typeConverter, patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..858ecd759d0972b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -631,3 +631,105 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
   %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
   return %1 : vector<1xf32>
 }
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  } {
+
+// CHECK-LABEL: @vector_load
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
+//       CHECK:   return %[[R0]] : vector<4xf32>
+func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_load_2d
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+//       CHECK:   %[[CST0_1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST0_2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST4:.+]] = spirv.Constant 4 : i32
+//       CHECK:   %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+//       CHECK:   %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+//       CHECK:   %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
+//       CHECK:   return %[[R0]] : vector<4xf32>
+func.func @vector_load_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %0 = vector.load %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_store
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
+func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx = arith.constant 0 : index
+  vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_store_2d
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+//       CHECK:   %[[CST0_1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST0_2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST4:.+]] = spirv.Constant 4 : i32
+//       CHECK:   %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+//       CHECK:   %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+//       CHECK:   %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>
+func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  vector.store %arg1, %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return
+}
+
+} // end module

Copy link
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks nice! Just a few nits. Thanks!

Add patterns to lower vector.load to spirv.load and vector.store to
spirv.store.
@Hsiangkai Hsiangkai force-pushed the vector-load-to-spirv branch from 194c8ac to 0959b03 Compare November 9, 2023 13:40
@Hsiangkai Hsiangkai merged commit 3049c76 into llvm:main Nov 24, 2023
@Hsiangkai Hsiangkai deleted the vector-load-to-spirv branch November 24, 2023 11:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants