
Commit 9ceea08

[mlir] im2col & l2cache on `cp.async.bulk.tensor.shared.cluster.global` (#72967)
This PR adds support for `im2col` and `l2cache` to `cp.async.bulk.tensor.shared.cluster.global`. The Op now supports all the traits of the corresponding PTX instruction. The PR also simplifies the types so we no longer need to write obvious types after `:`. The current structure of this operation looks like below:

```
nvvm.cp.async.bulk.tensor.shared.cluster.global
    %dest, %tmaDescriptor, %barrier,
    box[%crd0, %crd1, %crd2, %crd3, %crd4]
    im2col[%off0, %off1, %off2]          <-- introduced by this PR
    multicast_mask = %ctamask
    l2_cache_hint = %cacheHint           <-- introduced by this PR
    : !llvm.ptr<3>, !llvm.ptr
```
1 parent ed5404c · commit 9ceea08

6 files changed: +170, -78 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 45 additions & 21 deletions

```
@@ -1404,20 +1404,34 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
                    AttrSizedOperandSegments]>,
   Arguments<(ins LLVM_PointerShared:$dstMem,
                  LLVM_AnyPointer:$tmaDescriptor,
-                 LLVM_PointerShared:$mbar,
-                 Optional<I16>:$multicastMask,
                  Variadic<I32>:$coordinates,
+                 LLVM_PointerShared:$mbar,
+                 Variadic<I16>:$im2colOffsets,
+                 Optional<I16>:$multicastMask,
+                 Optional<I64>:$l2CacheHint,
                  PtxPredicate:$predicate)> {
   let description = [{
     Initiates an asynchronous copy operation on the tensor data from global
     memory to shared memory.
 
+    The Op has two load modes:
+    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
+    layout is preserved at the destination.
+
+    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+    The elements in the Bounding Box of the source tensor are rearranged into
+    columns at the destination. In this mode, the tensor has to be at least
+    3-dimensional.
+
     The `multicastMask` operand is optional. When it is present, the Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
     Operand `multicastMask` specifies the destination CTAs in the cluster such
     that each bit position in the 16-bit `multicastMask` operand corresponds to
-    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
+    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
 
+    The `l2CacheHint` operand is optional, and it is used to specify the cache
+    eviction policy that may be used during the memory access.
+
     [For more information, see PTX ISA]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
   }];
@@ -1426,32 +1440,42 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
     $dstMem `,`
     $tmaDescriptor `,`
     $mbar `,`
-    ( `multicast_mask` `=` $multicastMask^ `,` )?
-    `box` `[`$coordinates `]`
-    (`,` `predicate` `=` $predicate^)?
-    attr-dict `:` type(operands)
+    `box` `[`$coordinates `]`
+    (`im2col` `[` $im2colOffsets^ `]` )?
+    (`multicast_mask` `=` $multicastMask^ )?
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    (`predicate` `=` $predicate^)?
+    attr-dict `:` type($dstMem) `,` type($tmaDescriptor)
   }];
 
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
+      int im2colDim = getIm2colOffsets().size();
       int dim = getCoordinates().size();
       std::string ptx = "cp.async.bulk.tensor.";
       ptx += std::to_string(dim) + "d.";
-      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
-      if(getMulticastMask()) {
-        ptx += ".multicast::cluster";
-        if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
-        if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
-        if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
-        if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
-        if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
-      } else {
-        if(dim == 1) ptx += " [%0], [%1, {%3} ], [%2];";
-        if(dim == 2) ptx += " [%0], [%1, {%3, %4} ], [%2];";
-        if(dim == 3) ptx += " [%0], [%1, {%3, %4, %5} ], [%2];";
-        if(dim == 4) ptx += " [%0], [%1, {%3, %4, %5, %6} ], [%2];";
-        if(dim == 5) ptx += " [%0], [%1, {%3, %4, %5, %6, %7} ], [%2];";
+      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
+      if(im2colDim) ptx += ".im2col";
+      if(getMulticastMask()) ptx += ".multicast::cluster";
+      if(getL2CacheHint()) ptx += ".L2::cache_hint";
+
+      auto preg = [](int r) { return "%" + std::to_string(r); };
+
+      // Build Registers
+      ptx += " [%0], [%1, {";
+      int r = 2;
+      for(int i = 0; i < dim; i++) ptx += preg(r+i) + ",";
+      ptx.pop_back(); r += dim;
+      ptx += "} ], [%" + std::to_string(r++) + "]";
+      if(im2colDim) {
+        ptx += ",{";
+        for(int i = 0; i < im2colDim; i++) ptx += preg(r+i) + ",";
+        ptx.pop_back(); r += im2colDim;
+        ptx += "}";
       }
+      if(getMulticastMask()) ptx += ", " + preg(r++);
+      if(getL2CacheHint()) ptx += ", " + preg(r++);
+      ptx += ";";
       return ptx;
     }
   }];
```
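The rewritten `getPtx()` builds the instruction suffix and register list incrementally instead of enumerating every dimension by hand. Below is a minimal standalone sketch of that same register-numbering logic, compiled outside of MLIR; `buildPtx` and its boolean flags are illustrative stand-ins for the generated op accessors (`getIm2colOffsets()`, `getMulticastMask()`, `getL2CacheHint()`):

```
// Standalone sketch of the getPtx() logic above. Registers are assigned in
// operand order: %0 dstMem, %1 tmaDescriptor, one per coordinate, the
// mbarrier, then any im2col offsets, multicast mask, and L2 cache hint.
#include <iostream>
#include <string>

std::string buildPtx(int dim, int im2colDim, bool multicast, bool cacheHint) {
  std::string ptx = "cp.async.bulk.tensor.";
  ptx += std::to_string(dim) + "d.";
  ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
  if (im2colDim) ptx += ".im2col";
  if (multicast) ptx += ".multicast::cluster";
  if (cacheHint) ptx += ".L2::cache_hint";

  auto preg = [](int r) { return "%" + std::to_string(r); };

  // Coordinates occupy %2 .. %(2+dim-1); the mbarrier register follows.
  ptx += " [%0], [%1, {";
  int r = 2;
  for (int i = 0; i < dim; i++) ptx += preg(r + i) + ",";
  ptx.pop_back(); r += dim;
  ptx += "} ], [%" + std::to_string(r++) + "]";
  if (im2colDim) {
    ptx += ",{";
    for (int i = 0; i < im2colDim; i++) ptx += preg(r + i) + ",";
    ptx.pop_back(); r += im2colDim;
    ptx += "}";
  }
  if (multicast) ptx += ", " + preg(r++);
  if (cacheHint) ptx += ", " + preg(r++);
  ptx += ";";
  return ptx;
}

int main() {
  // 2-d tiled load with no optional operands.
  std::cout << buildPtx(2, 0, false, false) << "\n";
  // 5-d im2col load (5 - 2 = 3 offsets) with multicast and a cache hint.
  std::cout << buildPtx(5, 3, true, true) << "\n";
}
```

For the second call this yields the `.im2col.multicast::cluster.L2::cache_hint` variant with registers %2-%6 for the coordinates, %7 for the mbarrier, %8-%10 for the im2col offsets, and %11/%12 for the mask and hint.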

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 2 additions & 3 deletions

```
@@ -973,10 +973,9 @@ struct NVGPUTmaAsyncLoadOpLowering
     for (auto [index, value] : llvm::enumerate(coords)) {
       coords[index] = truncToI32(b, value);
     }
-
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
-        op, dest, adaptor.getTensorMapDescriptor(), barrier, Value(), coords,
-        adaptor.getPredicate());
+        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
+        ValueRange{}, Value{}, Value{}, adaptor.getPredicate());
     return success();
   }
 };
```
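Note the operand reordering here: the generated builder follows the new TableGen `Arguments` order, so the lowering now passes the coordinates before the barrier and fills the three operands it never uses with empty values. A rough sketch of that mapping, using a hypothetical plain struct in place of the generated builder signature:

```
// Hypothetical mirror of the new operand order; not the real MLIR builder.
#include <cstdio>
#include <string>
#include <vector>

struct TmaLoadOperands {
  std::string dstMem;                      // dest
  std::string tmaDescriptor;               // adaptor.getTensorMapDescriptor()
  std::vector<std::string> coordinates;    // coords (now before the barrier)
  std::string mbar;                        // barrier
  std::vector<std::string> im2colOffsets;  // ValueRange{} -> tiled mode
  std::string multicastMask;               // Value{}      -> absent
  std::string l2CacheHint;                 // Value{}      -> absent
  std::string predicate;                   // adaptor.getPredicate()
};

int main() {
  // The NVGPU lowering always emits a tiled load, so the im2col offsets,
  // multicast mask, and cache hint slots stay empty.
  TmaLoadOperands op{"%dest", "%tma", {"%crd0", "%crd1"}, "%bar",
                     {}, "", "", "%pred"};
  std::printf("%zu coordinate(s), %zu im2col offset(s)\n",
              op.coordinates.size(), op.im2colOffsets.size());
}
```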

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 12 additions & 2 deletions

```
@@ -76,8 +76,18 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  if (getCoordinates().size() > 5)
-    return emitError("Maximum 5 coordinates and dimension is supported.");
+  if (getCoordinates().empty() || getCoordinates().size() > 5)
+    return emitError("expects coordinates between 1 to 5 dimension");
+
+  // Check for im2col mode
+  if (!getIm2colOffsets().empty()) {
+    if (getCoordinates().size() < 3)
+      return emitError(
+          "to use im2col mode, the tensor has to be at least 3-dimensional");
+    if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
+      return emitError(
+          "im2col offsets must be 2 less than number of coordinates");
+  }
   return success();
 }
```
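The new verifier encodes three rules: a copy needs 1 to 5 coordinates, im2col mode needs a tensor of rank at least 3, and the number of im2col offsets must be exactly rank minus 2. A compact standalone restatement of the same rules (the `isValidBulkTensorCopy` helper is hypothetical, not part of the MLIR API):

```
// Standalone restatement of the verifier checks above.
#include <cassert>
#include <cstddef>

bool isValidBulkTensorCopy(size_t numCoords, size_t numIm2colOffsets) {
  if (numCoords == 0 || numCoords > 5)
    return false;                        // 1 to 5 coordinates supported
  if (numIm2colOffsets != 0) {           // im2col mode
    if (numCoords < 3)
      return false;                      // tensor must be at least 3-d
    if (numCoords != numIm2colOffsets + 2)
      return false;                      // offsets must equal rank - 2
  }
  return true;
}

int main() {
  assert(isValidBulkTensorCopy(2, 0));   // 2-d tiled load: ok
  assert(isValidBulkTensorCopy(3, 1));   // 3-d im2col, one offset: ok
  assert(!isValidBulkTensorCopy(2, 1));  // im2col on a 2-d tensor: rejected
  assert(!isValidBulkTensorCopy(5, 2));  // 5-d needs 3 offsets, not 2
  assert(!isValidBulkTensorCopy(6, 0));  // more than 5 dims: rejected
}
```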

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 12 additions & 12 deletions

```
@@ -653,15 +653,15 @@ func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d
   %c0 = arith.constant 0 : index
   %crd0 = arith.constant 0 : index
   %crd1 = arith.constant 0 : index
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}]
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}]
   nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d : !tensorMap1d, !mbarrier -> memref<128xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier[%c0] to %buffer5d : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
   func.return
 }
@@ -678,15 +678,15 @@ func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensor
   %c0 = arith.constant 0 : index
   %crd0 = arith.constant 0 : index
   %crd1 = arith.constant 0 : index
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}] predicate = %{{.*}}
   nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d, predicate = %p : !tensorMap1d, !mbarrier -> memref<128xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}] predicate = %{{.*}}
   nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d, predicate = %p : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}] predicate = %{{.*}}
   nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d, predicate = %p : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] predicate = %{{.*}}
   nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d, predicate = %p : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] predicate = %{{.*}}
   nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier[%c0] to %buffer5d, predicate = %p : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
   func.return
 }
@@ -737,8 +737,8 @@ module @mymodule {
   nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
   // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
-  // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+  // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
   return
 }
```
