
[mlir][gpu] Introduce gpu.dynamic_shared_memory Op #71546


Merged
merged 19 commits into llvm:main on Nov 16, 2023

Conversation

grypp (Member) commented on Nov 7, 2023

While the `gpu.launch` Op allows setting the size via the `dynamic_shared_memory_size` argument, accessing the dynamic shared memory is very convoluted. This PR implements the proposed `gpu.dynamic_shared_memory` Op, which aims to simplify the use of dynamic shared memory.

RFC: https://discourse.llvm.org/t/rfc-simplifying-dynamic-shared-memory-access-in-gpu/

Proposal from RFC
This PR introduces the `gpu.dynamic_shared_memory` Op to make the dynamic shared memory feature easy to use. It is a powerful feature that enables the allocation of shared memory at runtime, with the size set at kernel launch on the host. Afterwards, the memory can be accessed directly from the device. I believe a similar story exists for AMDGPU.
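
For context, the underlying mechanism this Op exposes is the one CUDA provides: the kernel sees a single unsized shared array whose size is fixed only at launch time. The sketch below is illustrative only; the kernel name, buffer name, and sizes are placeholders, not part of this PR:

```
// Minimal CUDA sketch of dynamic shared memory; names and sizes are illustrative.
#include <cuda_runtime.h>

__global__ void kernel(float *out) {
  // One unsized shared array per kernel; its size is fixed at launch time.
  extern __shared__ char dynamicSmem[];
  float *tile = reinterpret_cast<float *>(dynamicSmem);
  tile[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = tile[threadIdx.x];
}

int main() {
  float *out;
  cudaMalloc(&out, 128 * sizeof(float));
  // The third launch-config argument is the dynamic shared memory size in bytes,
  // the same quantity `dynamic_shared_memory_size` carries in `gpu.launch`.
  kernel<<<1, 128, 128 * sizeof(float)>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}
```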

Current way of using dynamic shared memory in MLIR

Let me illustrate the challenges of using dynamic shared memory in MLIR with an example below. The process involves several steps:

  • memref.global: create the 0-sized array symbol that LLVM's NVPTX backend expects
  • dynamic_shared_memory_size: set the size of dynamic shared memory at kernel launch
  • memref.get_global: access the global symbol
  • reinterpret_cast and subview: many Ops for pointer arithmetic
// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem  : memref<0xf16, 3> { alignment = 16 }
func.func @main() {
  // Step 2. Allocate shared memory
  gpu.launch blocks(...) threads(...)
    dynamic_shared_memory_size %c10000 {
    // Step 3. Access the global object
    %shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    // Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
    %4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
    %5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
    %6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
    %7 = memref.subview %6[0, 0][64, 64][1,1]  : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
    %8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
    // Step.5 Use
    "test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
    gpu.terminator
  }
  return
}

Let’s write the program above with the new Op:

func.func @main() {
    gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
        // Step 1: Obtain shared memory directly
        %shmem = gpu.dynamic_shared_memory : memref<?xi8, 3>
        %c147456 = arith.constant 147456 : index
        %c155648 = arith.constant 155648 : index
        %7 = memref.view %shmem[%c147456][] : memref<?xi8, 3> to memref<64x64xf16, 3>
        %8 = memref.view %shmem[%c155648][] : memref<?xi8, 3> to memref<64x64xf16, 3>

        // Step 2: Utilize the shared memory
        "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
        "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
        gpu.terminator
    }
    return
}
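
As a quick sanity check on the constants above (a sketch assuming the same f16 element type as the first example), the values passed to `memref.view` are simply the earlier element offsets scaled to bytes:

```
// Hypothetical host-side check; the offsets come from the two examples above.
#include <cassert>
#include <cstdio>

int main() {
  const long long elemOffset7 = 73728; // element offset of %7 in the subview chain
  const long long elemOffset8 = 77824; // element offset of %8 in the subview chain
  const long long bytesPerF16 = 2;

  // gpu.dynamic_shared_memory yields an i8 memref, so memref.view takes byte offsets.
  assert(elemOffset7 * bytesPerF16 == 147456); // %c147456
  assert(elemOffset8 * bytesPerF16 == 155648); // %c155648
  std::printf("%lld %lld\n", elemOffset7 * bytesPerF16, elemOffset8 * bytesPerF16);
  return 0;
}
```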

This PR resolves #72513

grypp added 2 commits November 7, 2023 16:14
llvmbot (Member) commented on Nov 7, 2023

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

Full diff: https://github.com/llvm/llvm-project/pull/71546.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUBase.td (+10)
  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+13)
  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+23)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h (+3)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+87)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h (+21)
  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+3)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+31-10)
  • (added) mlir/test/Dialect/GPU/dynamic-shared-memory.mlir (+35)
  • (modified) mlir/test/Dialect/GPU/invalid.mlir (+49)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 755c82d8b75c9c0..057b507c394e80f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,16 @@ def GPU_Dialect : Dialect {
     /// Returns the numeric value used to identify the private memory address
     /// space.
     static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+    
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the workgroup memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the workgroup memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);  
   }];
 
   let dependentDialects = ["arith::ArithDialect"];
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..286856324950eb7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -17,6 +17,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -32,6 +33,18 @@
 namespace mlir {
 namespace gpu {
 
+/// GPU memory space identifiers.
+enum GPUMemorySpace {
+  /// Generic memory space identifier.
+  kGenericMemorySpace = 0,
+
+  /// Global memory space identifier.
+  kGlobalMemorySpace = 1,
+
+  /// Shared memory space identifier.
+  kSharedMemorySpace = 3
+};
+
 /// Utility class for the GPU dialect to represent triples of `Value`s
 /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
 struct KernelDim3 {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6375d35f4311295..f3a37c62d3a7465 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,6 +433,29 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
   let hasVerifier = 1;
 }
 
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [] > {
+  let summary = "Get the memref for dynamic shared memory";
+  
+  let description = [{
+    This operation provides a memref pointer to the start of dynamic shared 
+    memory, often referred to as workgroup memory. It's important to note that
+     this dynamic shared memory needs to be allocated at kernel launch. One can 
+     conveniently utilize `the dynamic_shared_memory_size` parameter of 
+     `gpu.launch` for this purpose.
+   
+    Examples: 
+    ```mlir        
+    %0 = gpu.dynamic.shared.memory : memref<?xi8, 3>
+    %1 = memref.view %0[%c8192][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    %2 = memref.view %0[%c16384][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    ```
+  }];  
+  let arguments = (ins );
+  let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
+  let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
+  let hasVerifier = 1;
+}
+
 def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
 
 def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 8ff8f850a9c1858..08019e77ae6af8a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -27,6 +27,9 @@
 namespace mlir {
 namespace NVVM {
 
+// Shared memory has 128-bit alignment
+constexpr int kSharedMemoryAlignmentBit = 128;
+
 /// NVVM memory space identifiers.
 enum NVVMMemorySpace {
   /// Global memory space identifier.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6d2585aa30ab4c5..fbea498ee27caa7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -554,6 +555,92 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
 }
 
+/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
+/// or uses existing symbol.
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+    ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
+    const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit) {
+  std::optional<LLVM::GlobalOp> existingGlobalOp;
+
+  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+  assert(funcOp && "cannot find llvm.func op");
+
+  gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
+  assert(moduleOp && "cannot find gpu.module op");
+
+  // Use already generated global op if it exists
+  int index = 0;
+  std::string prefix = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+  moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+      if (arrayType.getNumElements() == 0) {
+        existingGlobalOp = globalOp;
+        return WalkResult::interrupt();
+      }
+    }
+    if (globalOp.getSymName().startswith(prefix))
+      index++;
+    return WalkResult::advance();
+  });
+  if (existingGlobalOp.has_value())
+    return existingGlobalOp.value();
+
+  // Generate a new global op
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(&moduleOp.front());
+
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+      typeConverter->convertType(memrefType.getElementType()), 0);
+  std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
+  // TODO: better alignment calculation
+  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+  return rewriter.create<LLVM::GlobalOp>(
+      funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+      LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
+      mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MemRefType memrefType = op.getResultMemref().getType();
+  auto elementType = typeConverter->convertType(memrefType.getElementType());
+  assert(memrefType && "memref is not valid");
+  
+  // Step 1: Generate a memref<0xi8> type
+  MemRefLayoutAttrInterface layout = {};
+  auto memrefType0sz = MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());  
+
+  // Step 2: Generate a global symbol or existing for the dynamic shared
+  // memory with memref<0xi8> type
+  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
+      rewriter, op, getTypeConverter(), memrefType0sz ,alignmentBit);
+  assert(shmemOp && "cannot find module op or failed generating global op");
+
+  // Step 3. Get address of the global symbol
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+  Type baseType = basePtr->getResultTypes().front();
+
+  // Step 4. Generate GEP using offsets
+  SmallVector<LLVM::GEPArg> gepArgs = {0};
+  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
+                                                basePtr, gepArgs);
+  // Step 5. Create a memref descriptor
+  SmallVector<Value> shape, strides;
+  Value sizeBytes;
+  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
+                           sizeBytes);
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
+
+  // Step 5. Replace the op with memref descriptor
+  rewriter.replaceOp(op, {memRefDescriptor});
+  return success();
+}
+
 void mlir::populateGpuMemorySpaceAttributeConversions(
     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
   typeConverter.addTypeAttributeConversion(
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bd90286494d8035..a77db4a036bad3f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
 
 namespace mlir {
 
+/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
+/// create a 0-sized global array symbol similar as LLVM expects. It constructs
+/// a memref descriptor with these values and return it.
+struct GPUDynamicSharedMemoryOpLowering
+    : public ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
+  GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+                                   unsigned alignmentBit = 0)
+      : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+        alignmentBit(alignmentBit) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+
+private:
+  // Alignment bit
+  unsigned alignmentBit;
+};
+
 struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
   GPUFuncOpLowering(const LLVMTypeConverter &converter,
                     unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..86a77f557cb9579 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -325,6 +325,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
            GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
           converter);
 
+  patterns.add<GPUDynamicSharedMemoryOpLowering>(
+      converter, NVVM::kSharedMemoryAlignmentBit);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5eb2cadc884e151..3216e82147da907 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -164,17 +164,20 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
 // GPUDialect
 //===----------------------------------------------------------------------===//
 
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
-  /// Generic memory space identifier.
-  kGenericMemorySpace = 0,
-
-  /// Global memory space identifier.
-  kGlobalMemorySpace = 1,
+bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == GPUMemorySpace::kSharedMemorySpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == getWorkgroupAddressSpace();
+  return false;
+}
 
-  /// Shared memory space identifier.
-  kSharedMemorySpace = 3
-};
+bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isWorkgroupMemoryAddressSpace(memorySpace);
+}
 
 bool GPUDialect::isKernel(Operation *op) {
   UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
@@ -2024,6 +2027,24 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicSharedMemoryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+  MemRefType memrefType = getResultMemref().getType();
+  // Check address space
+  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+    return emitOpError() << "Address space must be "
+                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
+                         << stringifyEnum(gpu::AddressSpace::Workgroup)
+                         << "> or " << int(GPUMemorySpace::kSharedMemorySpace);
+  }
+  if(memrefType.hasStaticShape()) 
+    return emitOpError() << "result memref type must be memref<?xi8>";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU target options
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
new file mode 100644
index 000000000000000..a2706fa6bde7ae9
--- /dev/null
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
+
+gpu.module @modules {
+  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  
+  // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
+  // CHECK-SAME: %[[arg0:.+]]: i64)
+  gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {    
+    %c1 = arith.constant 1 : index
+    %c100 = arith.constant 100 : index
+    %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+
+    %0 = memref.view %shmem[%c100][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+    
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][100] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+    gpu.return
+  }
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index c8c0b7d24bc3ab2..768289819cc0e09 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -640,3 +640,52 @@ module {
   // expected-error @+1 {{'gpu.binary' op attribute 'offloadingHandler' failed to satisfy constraint: any attribute with the `OffloadingTranslationAttrTrait` trait.}}
   gpu.binary @binary <1> [#gpu.object<#nvvm.target, "">]
 }
+
+// -----
+
+func.func @main() {
+  %shmemSize = arith.constant 10000 : i32
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             dynamic_shared_memory_size %shmemSize
+  {
+    // expected-error @+1 {{op Address space must be address_space<workgroup> or 3}}
+    %0 = gpu.dynamic_shared_memory : memref<?xi8>  
+    gpu.terminator
+  }
+  return
+}
+
+
+// -----
+
+func.func @main() {
+  %shmemSize = arith.constant 8192 : i32
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             dynamic_shared_memory_size %shmemSize
+  {
+    // expected-error @+1 {{result memref type must be memref<?xi8>}}
+    %0 = gpu.dynamic_shared_memory : memref<1xi8,3>  
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @main(%arg0 : index) {
+  %shmemSize = arith.constant 8192 : i32
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             dynamic_shared_memory_size %shmemSize
+  {
+    // expected-error @+1 {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
+    %0 = gpu.dynamic_shared_memory : memref<?xf32,3>  
+    gpu.terminator
+  }
+  return
+}
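
The pattern is registered for the NVVM path above with `NVVM::kSharedMemoryAlignmentBit`; a downstream, non-NVVM pipeline could reuse it with its own alignment. The following is only a hedged sketch: the function name and the 64-bit alignment are hypothetical, and `GPUOpsLowering.h` is an in-tree header under `lib/Conversion/GPUCommon` rather than a public include.

```
#include "GPUOpsLowering.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical downstream registration, mirroring the call in
// populateGpuToNVVMConversionPatterns but with a different alignment.
static void populateMyGpuToLLVMPatterns(LLVMTypeConverter &converter,
                                        RewritePatternSet &patterns) {
  // The second constructor argument is the shared-memory alignment in bits.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter,
                                                 /*alignmentBit=*/64);
}
```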


github-actions bot commented Nov 7, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

// LLVM::GlobalOp is suitable for shared memory, return it.
llvm::StringSet<> existingGlobalNames;
for (auto globalOp :
     moduleOp->getRegion(0).front().template getOps<LLVM::GlobalOp>()) {
Collaborator

Isn't there a getBody() or equivalent in moduleOp?
Also not sure why the template is needed here?

Member Author (grypp)

the moduleOp is actually Operation*, so I could not see getBody

Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
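
To make the distinction concrete, here is a small illustrative sketch (not code from the PR): a parent held as a plain `Operation*` only exposes the generic region/block accessors, whereas the typed `gpu::GPUModuleOp` exposes `getBody()`.

```
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;

static void visitGlobals(LLVM::LLVMFuncOp funcOp) {
  // Generic handle: getRegion(0).front() stands in for a typed getBody().
  Operation *symbolTableOp =
      funcOp->getParentWithTrait<OpTrait::SymbolTable>();
  for (auto globalOp :
       symbolTableOp->getRegion(0).front().getOps<LLVM::GlobalOp>())
    (void)globalOp;

  // Typed handle: the convenience accessor is available.
  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
    for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>())
      (void)globalOp;
}
```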

grypp merged commit ea84897 into llvm:main on Nov 16, 2023
grypp deleted the dynamic-shmem-attempt-2 branch on November 16, 2023 at 13:43
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
Development

Successfully merging this pull request may close these issues.

Dynamic Shared Memory Support
6 participants