Skip to content

Commit 8dd0d95

Browse files
authored
[mlir][nvgpu] Add nvgpu.tma.async.store (#77811)
PR adds `nvgpu.tma.async.store` Op for asynchronous stores using the Tensor Memory Access (TMA) unit. It also implements Op lowering to NVVM dialect. The Op currently performs asynchronous stores of a tile memory region from shared to global memory for a single CTA.
1 parent 0e1037e commit 8dd0d95

File tree

4 files changed

+115
-0
lines changed

4 files changed

+115
-0
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,28 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
661661

662662
}
663663

664+
def NVGPU_TmaAsyncStoreOp : NVGPU_Op<"tma.async.store", [AttrSizedOperandSegments]> {
665+
let summary = "TMA asynchronous store";
666+
let description = [{
667+
The Op stores a tile memory region from shared memory to global memory by
668+
the Tensor Memory Access (TMA) unit.
669+
670+
`$tensorMapDescriptor` is a tensor map descriptor which has information about
671+
the tile shape. The descriptor is created by `nvgpu.tma.create.descriptor`.
672+
}];
673+
let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
674+
NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
675+
Variadic<Index>:$coordinates,
676+
Optional<I1>:$predicate);
677+
let assemblyFormat = [{
678+
$src `to` $tensorMapDescriptor `[` $coordinates `]`
679+
(`,` `predicate` `=` $predicate^)?
680+
attr-dict `:` type($src)
681+
`->` type($tensorMapDescriptor)
682+
}];
683+
let hasVerifier = 1;
684+
}
685+
664686
def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
665687
let summary = "TMA create descriptor";
666688
let description = [{

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,29 @@ struct NVGPUTmaAsyncLoadOpLowering
995995
return success();
996996
}
997997
};
998+
999+
struct NVGPUTmaAsyncStoreOpLowering
1000+
: public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1001+
using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1002+
LogicalResult
1003+
matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1004+
ConversionPatternRewriter &rewriter) const override {
1005+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1006+
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1007+
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1008+
adaptor.getSrc(), {}, rewriter);
1009+
SmallVector<Value> coords = adaptor.getCoordinates();
1010+
for (auto [index, value] : llvm::enumerate(coords)) {
1011+
coords[index] = truncToI32(b, value);
1012+
}
1013+
1014+
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1015+
op, adaptor.getTensorMapDescriptor(), dest, coords,
1016+
adaptor.getPredicate());
1017+
return success();
1018+
}
1019+
};
1020+
9981021
struct NVGPUGenerateWarpgroupDescriptorLowering
9991022
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
10001023
using ConvertOpToLLVMPattern<
@@ -1639,6 +1662,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
16391662
NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
16401663
NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
16411664
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1665+
NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
16421666
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
16431667
NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
16441668
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,29 @@ LogicalResult TmaAsyncLoadOp::verify() {
405405
return success();
406406
}
407407

408+
//===----------------------------------------------------------------------===//
409+
// NVGPU_TmaAsyncStoreOp
410+
//===----------------------------------------------------------------------===//
411+
412+
LogicalResult TmaAsyncStoreOp::verify() {
413+
std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
414+
*this, getTensorMapDescriptor().getType(), getSrc().getType());
415+
if (error.has_value())
416+
return error.value();
417+
418+
if (getCoordinates().size() > kMaxTMATensorDimension) {
419+
return emitError() << "Maximum " << kMaxTMATensorDimension
420+
<< " coordinates are supported.";
421+
}
422+
if (getCoordinates().size() !=
423+
size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
424+
return emitError() << "number of coordinates do not match with the rank of "
425+
"tensor descriptor map.";
426+
}
427+
428+
return success();
429+
}
430+
408431
LogicalResult TmaCreateDescriptorOp::verify() {
409432
if (getBoxDimensions().size() > kMaxTMATensorDimension) {
410433
return emitError() << "Maximum " << kMaxTMATensorDimension

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
728728
func.return
729729
}
730730

731+
func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
732+
%buffer1d: memref<128xf32,3>,
733+
%buffer2d: memref<32x32xf32,3>,
734+
%buffer3d: memref<2x32x32xf32,3>,
735+
%buffer4d: memref<2x2x32x32xf32,3>,
736+
%buffer5d: memref<2x2x2x32x32xf32,3>) {
737+
%c0 = arith.constant 0 : index
738+
%crd0 = arith.constant 0 : index
739+
%crd1 = arith.constant 0 : index
740+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}]
741+
nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d
742+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}]
743+
nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x32xf32,3> -> !tensorMap2d
744+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
745+
nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x32xf32,3> -> !tensorMap3d
746+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
747+
nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0] : memref<2x2x32x32xf32,3> -> !tensorMap4d
748+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
749+
nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0] : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
750+
func.return
751+
}
752+
753+
754+
func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
755+
%buffer1d: memref<128xf32,3>,
756+
%buffer2d: memref<32x32xf32,3>,
757+
%buffer3d: memref<2x32x32xf32,3>,
758+
%buffer4d: memref<2x2x32x32xf32,3>,
759+
%buffer5d: memref<2x2x2x32x32xf32,3>,
760+
%p: i1) {
761+
%c0 = arith.constant 0 : index
762+
%crd0 = arith.constant 0 : index
763+
%crd1 = arith.constant 0 : index
764+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
765+
nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
766+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
767+
nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d
768+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
769+
nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x32xf32,3> -> !tensorMap3d
770+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
771+
nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p : memref<2x2x32x32xf32,3> -> !tensorMap4d
772+
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
773+
nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
774+
func.return
775+
}
776+
731777
func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
732778
%crd0 = arith.constant 64 : index
733779
%crd1 = arith.constant 128 : index

0 commit comments

Comments
 (0)