[mlir][nvgpu] Add nvgpu.tma.async.store #77811


Merged: 1 commit into llvm:main on Jan 15, 2024

Conversation


grypp (Member) commented Jan 11, 2024

PR adds the `nvgpu.tma.async.store` Op for asynchronous stores using the Tensor Memory Access (TMA) unit.

It also implements Op lowering to the NVVM dialect. The Op currently performs asynchronous stores of a tile memory region from shared to global memory for a single CTA.
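For illustration, a minimal sketch of the Op in use, mirroring the lowering tests in the diff below (`!tensorMap2d` is a type alias for the `!nvgpu.tensormap.descriptor` type defined in the test file; the predicate operand is optional):

```mlir
// Asynchronously store a 32x32 tile from shared memory (address space 3) to
// the global-memory region described by the TMA descriptor, at tile
// coordinates [%crd0, %crd1]. The optional predicate gates the store so that
// only one thread issues it.
nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p
    : memref<32x32xf32, 3> -> !tensorMap2d
```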


llvmbot (Member) commented Jan 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes

PR adds nvgpu.tma.async.store Op for asynchronous stores using the Tensor Memory Access (TMA) unit.

It also implements Op lowering to NVVM dialect. The Op currently performs asynchronous stores of a tile memory region from shared to global memory for a single CTA.


Full diff: https://github.com/llvm/llvm-project/pull/77811.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+22)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+24)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+23)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+46)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 7e139663d74b47..239a5f1e2bc298 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -661,6 +661,28 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
 
 }
 
+def NVGPU_TmaAsyncStoreOp : NVGPU_Op<"tma.async.store", [AttrSizedOperandSegments]> {
+  let summary = "TMA asynchronous store";
+  let description = [{
+    The Op stores a tile memory region from shared memory to global memory
+    using the Tensor Memory Access (TMA) unit.
+
+    `$tensorMapDescriptor` is a tensor map descriptor that carries information
+    about the tile shape. The descriptor is created by
+    `nvgpu.tma.create.descriptor`.
+  }];  
+  let arguments = (ins  Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
+                        NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
+                        Variadic<Index>:$coordinates, 
+                        Optional<I1>:$predicate);
+  let assemblyFormat = [{
+      $src `to` $tensorMapDescriptor `[` $coordinates `]`
+      (`,` `predicate` `=` $predicate^)?
+      attr-dict `:` type($src)
+      `->` type($tensorMapDescriptor)
+  }];
+  let hasVerifier = 1;
+}
+
 def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
   let summary = "TMA create descriptor";
   let description = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index db84e5cf62a5e9..759766275de4a5 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -995,6 +995,29 @@ struct NVGPUTmaAsyncLoadOpLowering
     return success();
   }
 };
+
+struct NVGPUTmaAsyncStoreOpLowering
+    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
+  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
+  LogicalResult
+  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
+    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
+                                      adaptor.getSrc(), {}, rewriter);
+    SmallVector<Value> coords = adaptor.getCoordinates();
+    for (auto [index, value] : llvm::enumerate(coords)) {
+      coords[index] = truncToI32(b, value);
+    }
+
+    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
+        op, adaptor.getTensorMapDescriptor(), dest, coords,
+        adaptor.getPredicate());
+    return success();
+  }
+};
+
 struct NVGPUGenerateWarpgroupDescriptorLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
   using ConvertOpToLLVMPattern<
@@ -1639,6 +1662,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
       NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
       NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
+      NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
       NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index c9756ae8fc11ce..5ffa854e97cb17 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -402,6 +402,29 @@ LogicalResult TmaAsyncLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaAsyncStoreOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TmaAsyncStoreOp::verify() {
+  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+      *this, getTensorMapDescriptor().getType(), getSrc().getType());
+  if (error.has_value())
+    return error.value();
+
+  if (getCoordinates().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
+  }
+  if (getCoordinates().size() !=
+      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
+    return emitError() << "number of coordinates do not match with the rank of "
+                          "tensor descriptor map.";
+  }
+
+  return success();
+}
+
 LogicalResult TmaCreateDescriptorOp::verify() {
   if (getBoxDimensions().size() > kMaxTMATensorDimension) {
     return emitError() << "Maximum " << kMaxTMATensorDimension
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b8a0f75d1cc8b9..edccd7e80603bd 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
   func.return 
 }
 
+func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d, 
+                           %buffer1d: memref<128xf32,3>,      
+                           %buffer2d: memref<32x32xf32,3>,    
+                           %buffer3d: memref<2x32x32xf32,3>,  
+                           %buffer4d: memref<2x2x32x32xf32,3>,  
+                           %buffer5d: memref<2x2x2x32x32xf32,3>) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}] 
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d 
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1]  : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0]  : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0]  : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0]  : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return 
+}
+
+
+func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d, 
+                           %buffer1d: memref<128xf32,3>,      
+                           %buffer2d: memref<32x32xf32,3>,    
+                           %buffer3d: memref<2x32x32xf32,3>,  
+                           %buffer4d: memref<2x2x32x32xf32,3>,  
+                           %buffer5d: memref<2x2x2x32x32xf32,3>,
+                           %p: i1) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p  : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p  : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p  : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p  : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return 
+}
+
 func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
   %crd0 = arith.constant 64 : index
   %crd1 = arith.constant 128 : index

grypp requested review from joker-eph and manishucsd on January 11, 2024 at 18:23

llvmbot (Member) commented Jan 11, 2024

@llvm/pr-subscribers-mlir-nvgpu

grypp requested a review from jpienaar on January 11, 2024 at 18:23
      `->` type($tensorMapDescriptor)
  }];
  let hasVerifier = 1;
}
Member:

I think we also need to add the commit/wait_group version for TMA stores? Otherwise this instruction is not that useful? Unless you expect people to be calling the NVVM versions for now.

grypp (Member, Author):

Valid point. I suggest using these OPs directly from the NVVM dialect.

Why? The NVGPU dialect acts as a bridge between high-level dialects (such as memref, vector, etc.) and the NVVM dialect. The `cp.async.bulk.commit.group` instruction is straightforward to use, and introducing it to the NVGPU dialect would only result in duplication.

If we need higher-level abstractions (like what we already have for operations such as `mbarrier.group` or `nvgpu.tma.descriptor`), then we can introduce specific Ops in the NVGPU dialect.

What do you think?
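For concreteness, a minimal sketch of what that would look like (the `nvvm.cp.async.bulk.commit.group` spelling is an assumption here, not confirmed by this PR; the matching bulk wait-group Op is still missing, see below):

```mlir
// Issue the TMA store through NVGPU, then commit it to a bulk async-copy
// group directly via the NVVM dialect, with no NVGPU wrapper in between.
nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0]
    : memref<128xf32, 3> -> !tensorMap1d
nvvm.cp.async.bulk.commit.group
```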

grypp (Member, Author):

(just noticed that cp.async.bulk.wait_group.read is missing in NVVM)

Member:

Sure, sounds good to me. Perhaps it would be a good idea to add a pointer to the NVVM op to the description here.

    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
Member:

Not for this PR, but we should finally add reversal to the coords to make the dimension ordering work as MLIR normally expects.

grypp (Member, Author):

agreed, let's do this in another PR
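For reference, a rough sketch of that follow-up against the lowering pattern above (the exact placement of the reversal, and whether the descriptor layout needs a matching change, are assumptions):

```cpp
// In NVGPUTmaAsyncStoreOpLowering::matchAndRewrite: reverse the coordinates
// so the NVGPU op can take them in MLIR's usual slowest-to-fastest dimension
// order, while the NVVM/TMA op receives the fastest-varying coordinate first.
SmallVector<Value> coords = adaptor.getCoordinates();
std::reverse(coords.begin(), coords.end());
for (auto [index, value] : llvm::enumerate(coords))
  coords[index] = truncToI32(b, value);
```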


grypp merged commit 8dd0d95 into llvm:main on Jan 15, 2024
grypp deleted the nvgpu-tma-store branch on January 15, 2024 at 10:44
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
PR adds `nvgpu.tma.async.store` Op for asynchronous stores using the
Tensor Memory Access (TMA) unit.

It also implements Op lowering to NVVM dialect. The Op currently
performs asynchronous stores of a tile memory region from shared to
global memory for a single CTA.