Skip to content

Commit 4319e19

Browse files
authored
[mlir][nvgpu] Introduce Multicast Capability to nvgpu.tma.async.load (#76935)
This PR adds multicast support to the `nvgpu.tma.async.load` Op. The lower-level `nvvm.cp.async.bulk.tensor.shared.cluster.global` NVVM Op already supports a multicast mask; this PR introduces an optional `multicastMask` operand on the NVGPU Op and lowers it through to that NVVM operation.
1 parent a001e97 commit 4319e19

File tree

4 files changed

+35
-9
lines changed

4 files changed

+35
-9
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -642,16 +642,18 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
642642

643643
The Op uses `$barrier` mbarrier based completion mechanism.
644644
}];
645-
let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
646-
NVGPU_MBarrierGroup:$barriers,
647-
NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
648-
Variadic<Index>:$coordinates,
649-
Index:$mbarId,
650-
Optional<I1>:$predicate);
645+
let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
646+
NVGPU_MBarrierGroup:$barriers,
647+
NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
648+
Variadic<Index>:$coordinates,
649+
Index:$mbarId,
650+
Optional<I16>:$multicastMask,
651+
Optional<I1>:$predicate);
651652
let assemblyFormat = [{
652653
$tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]`
653654
`to` $dst
654-
(`,` `predicate` `=` $predicate^)?
655+
(`multicast_mask` `=` $multicastMask^ )?
656+
(`,` `predicate` `=` $predicate^)?
655657
attr-dict `:` type($tensorMapDescriptor) `,` type($barriers)
656658
`->` type($dst)
657659
}];

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,8 @@ struct NVGPUTmaAsyncLoadOpLowering
990990
}
991991
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
992992
op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
993-
ValueRange{}, Value{}, Value{}, adaptor.getPredicate());
993+
ValueRange{}, adaptor.getMulticastMask(), Value{},
994+
adaptor.getPredicate());
994995
return success();
995996
}
996997
};

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
980980
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
981981
Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
982982
loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
983-
Value());
983+
Value(), Value());
984984
loadOps.push_back(loadOp);
985985
auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
986986
SmallVector<AffineExpr> symbols(mixedSizes.size());

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,29 @@ func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensor
704704
func.return
705705
}
706706

707+
func.func @async_tma_load_multicast(
708+
%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d,
709+
%tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d,
710+
%tensorMap5d: !tensorMap5d, %buffer1d: memref<128xf32,3>,
711+
%buffer2d: memref<32x32xf32,3>, %buffer3d: memref<2x32x32xf32,3>,
712+
%buffer4d: memref<2x2x32x32xf32,3>, %buffer5d: memref<2x2x2x32x32xf32,3>,
713+
%mbarrier: !mbarrier,
714+
%multicastMask: i16) {
715+
%c0 = arith.constant 0 : index
716+
%crd0 = arith.constant 0 : index
717+
%crd1 = arith.constant 0 : index
718+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}]
719+
nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d multicast_mask = %multicastMask : !tensorMap1d, !mbarrier -> memref<128xf32,3>
720+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}]
721+
nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d multicast_mask = %multicastMask : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
722+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}]
723+
nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d multicast_mask = %multicastMask : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
724+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
725+
nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d multicast_mask = %multicastMask : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
726+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
727+
nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier[%c0] to %buffer5d multicast_mask = %multicastMask : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
728+
func.return
729+
}
707730

708731
func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
709732
%crd0 = arith.constant 64 : index

0 commit comments

Comments (0)