[mlir][nvgpu] Add nvgpu.tma.async.store
#77811
Conversation
PR adds `nvgpu.tma.async.store` Op for asynchronous stores using the Tensor Memory Access (TMA) unit. It also implements Op lowering to the NVVM dialect. The Op currently performs asynchronous stores of a tile memory region from shared to global memory for a single CTA.
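As a quick illustration (a minimal sketch distilled from the tests in this PR; the `!tensorMap2d` type alias and SSA value names come from the test file), the op takes a shared-memory buffer, a tensor map descriptor, and per-dimension coordinates:

  // Store a 32x32 tile held in shared memory (address space 3) to global
  // memory at coordinates (%crd0, %crd1) described by the tensor map.
  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1]
      : memref<32x32xf32, 3> -> !tensorMap2d

After `-convert-nvgpu-to-nvvm` this becomes, roughly, `nvvm.cp.async.bulk.tensor.global.shared.cta` with the coordinates truncated to i32 and passed as the `box` operands, as the test cases in the diff below show.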
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Guray Ozen (grypp) Changes: PR adds `nvgpu.tma.async.store` Op for asynchronous stores using the Tensor Memory Access (TMA) unit. It also implements Op lowering to the NVVM dialect. The Op currently performs asynchronous stores of a tile memory region from shared to global memory for a single CTA. Full diff: https://github.com/llvm/llvm-project/pull/77811.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 7e139663d74b47..239a5f1e2bc298 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -661,6 +661,28 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
}
+def NVGPU_TmaAsyncStoreOp : NVGPU_Op<"tma.async.store", [AttrSizedOperandSegments]> {
+ let summary = "TMA asynchronous store";
+ let description = [{
+ The Op stores a tile memory region from shared memory to global memory
+ asynchronously using the Tensor Memory Access (TMA) unit.
+
+ `$tensorMapDescriptor` is a tensor map descriptor that carries information about
+ the tile shape. The descriptor is created by `nvgpu.tma.create.descriptor`.
+ }];
+ let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
+ NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
+ Variadic<Index>:$coordinates,
+ Optional<I1>:$predicate);
+ let assemblyFormat = [{
+ $src `to` $tensorMapDescriptor `[` $coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type($src)
+ `->` type($tensorMapDescriptor)
+ }];
+ let hasVerifier = 1;
+}
+
def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
let summary = "TMA create descriptor";
let description = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index db84e5cf62a5e9..759766275de4a5 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -995,6 +995,29 @@ struct NVGPUTmaAsyncLoadOpLowering
return success();
}
};
+
+struct NVGPUTmaAsyncStoreOpLowering
+ : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
+ using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
+ LogicalResult
+ matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
+ Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
+ adaptor.getSrc(), {}, rewriter);
+ SmallVector<Value> coords = adaptor.getCoordinates();
+ for (auto [index, value] : llvm::enumerate(coords)) {
+ coords[index] = truncToI32(b, value);
+ }
+
+ rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
+ op, adaptor.getTensorMapDescriptor(), dest, coords,
+ adaptor.getPredicate());
+ return success();
+ }
+};
+
struct NVGPUGenerateWarpgroupDescriptorLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
using ConvertOpToLLVMPattern<
@@ -1639,6 +1662,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
+ NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index c9756ae8fc11ce..5ffa854e97cb17 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -402,6 +402,29 @@ LogicalResult TmaAsyncLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaAsyncStoreOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TmaAsyncStoreOp::verify() {
+ std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+ *this, getTensorMapDescriptor().getType(), getSrc().getType());
+ if (error.has_value())
+ return error.value();
+
+ if (getCoordinates().size() > kMaxTMATensorDimension) {
+ return emitError() << "Maximum " << kMaxTMATensorDimension
+ << " coordinates are supported.";
+ }
+ if (getCoordinates().size() !=
+ size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
+ return emitError() << "number of coordinates do not match with the rank of "
+ "tensor descriptor map.";
+ }
+
+ return success();
+}
+
LogicalResult TmaCreateDescriptorOp::verify() {
if (getBoxDimensions().size() > kMaxTMATensorDimension) {
return emitError() << "Maximum " << kMaxTMATensorDimension
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b8a0f75d1cc8b9..edccd7e80603bd 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
func.return
}
+func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+ %buffer1d: memref<128xf32,3>,
+ %buffer2d: memref<32x32xf32,3>,
+ %buffer3d: memref<2x32x32xf32,3>,
+ %buffer4d: memref<2x2x32x32xf32,3>,
+ %buffer5d: memref<2x2x2x32x32xf32,3>) {
+ %c0 = arith.constant 0 : index
+ %crd0 = arith.constant 0 : index
+ %crd1 = arith.constant 0 : index
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}]
+ nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x32xf32,3> -> !tensorMap2d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x32xf32,3> -> !tensorMap3d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0] : memref<2x2x32x32xf32,3> -> !tensorMap4d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0] : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+ func.return
+}
+
+
+func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+ %buffer1d: memref<128xf32,3>,
+ %buffer2d: memref<32x32xf32,3>,
+ %buffer3d: memref<2x32x32xf32,3>,
+ %buffer4d: memref<2x2x32x32xf32,3>,
+ %buffer5d: memref<2x2x2x32x32xf32,3>,
+ %p: i1) {
+ %c0 = arith.constant 0 : index
+ %crd0 = arith.constant 0 : index
+ %crd1 = arith.constant 0 : index
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x32xf32,3> -> !tensorMap3d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p : memref<2x2x32x32xf32,3> -> !tensorMap4d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+ func.return
+}
+
func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
%crd0 = arith.constant 64 : index
%crd1 = arith.constant 128 : index
@llvm/pr-subscribers-mlir-nvgpu Author: Guray Ozen (grypp) Changes: PR adds `nvgpu.tma.async.store` Op for asynchronous stores using the Tensor Memory Access (TMA) unit. It also implements Op lowering to the NVVM dialect. Full diff: https://github.com/llvm/llvm-project/pull/77811.diff
    `->` type($tensorMapDescriptor)
  }];
  let hasVerifier = 1;
}
I think we also need to add the commit/wait_group version for TMA stores? Otherwise this instruction is not that useful? Unless you expect people to be calling the NVVM versions for now.
Valid point. I suggest using these Ops directly from the NVVM dialect.
Why? The NVGPU dialect acts as a bridge between high-level dialects (such as memref, vector, etc.) and the NVVM dialect. The cp.async.bulk.commit.group is straightforward to use, and introducing it to the NVGPU dialect would only result in replication.
If we need higher-level abstractions (as we have for operations like mbarrier.group or nvgpu.tma.descriptor), then we can introduce specific Ops in the NVGPU dialect.
What do you think?
(just noticed that cp.async.bulk.wait_group.read is missing in NVVM)
Sure, sounds good to me. Perhaps it would be a good idea to add a pointer to the NVVM op to the description here.
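To make the intended usage concrete, a rough sketch of the pattern discussed in this thread follows. It is only an illustration under assumptions: the commit/wait ops are taken directly from the NVVM dialect as suggested above, their exact names and assembly syntax are not guaranteed here, and (as noted) the wait_group counterpart was still missing from NVVM at the time of this discussion.

  // Issue the TMA store, then commit outstanding bulk copies into a group.
  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1]
      : memref<32x32xf32, 3> -> !tensorMap2d
  nvvm.cp.async.bulk.commit.group
  // Hypothetical wait until the shared-memory source can be reused;
  // this op did not yet exist in NVVM when this was written.
  nvvm.cp.async.bulk.wait_group 0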
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
for (auto [index, value] : llvm::enumerate(coords)) { |
Not for this PR, but we should finally add reversal to the coords to make the dimension ordering work as MLIR normally expects.
agreed, let's do this in another PR
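For reference, a minimal sketch of the reversal discussed above, applied to the lowering in this PR (an illustration of the idea only, not part of this change):

  // Hypothetical follow-up: reverse the coordinate order so that the
  // outermost MLIR dimension maps to the slowest-varying TMA dimension,
  // then truncate each coordinate to i32 as before.
  SmallVector<Value> coords = adaptor.getCoordinates();
  std::reverse(coords.begin(), coords.end());
  for (auto [index, value] : llvm::enumerate(coords))
    coords[index] = truncToI32(b, value);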