
Commit 50a76a7

[MLIR][NVGPU] Handling Offset in nvgpu.tma.async.load
When using the `nvgpu.tma.async.load` op to asynchronously load data into shared memory, the lowering failed to account for the destination memref's offset, potentially leading to incorrect memory accesses. Using an offset is common practice, especially with dynamic shared memory. This patch addresses the problem by taking the offset into account when computing the destination pointer.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D157380
1 parent af635a5 commit 50a76a7
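
For context, the pattern this fix targets looks roughly like the following condensed sketch (illustrative only, derived from the test added below, not part of the commit): several tiles are carved out of a single dynamic shared-memory buffer, so all but the first view carry a nonzero static offset that the lowering must fold into the TMA destination pointer.

```mlir
// Illustrative sketch: two f16 tiles sharing one dynamic shared-memory buffer.
// The second view starts 8192 elements (128 * 64) into the buffer, so its
// memref type carries `offset: 8192`, which nvgpu.tma.async.load must honor
// when computing the destination address.
memref.global "private" @dynamicShmem : memref<0xf16, 3>

func.func @shmem_views() {
  %base = memref.get_global @dynamicShmem : memref<0xf16, 3>
  %lhs = memref.reinterpret_cast %base to offset: [0], sizes: [128, 64], strides: [64, 1]
      : memref<0xf16, 3> to memref<128x64xf16, 3>
  %rhs = memref.reinterpret_cast %base to offset: [8192], sizes: [64, 128], strides: [128, 1]
      : memref<0xf16, 3> to memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
  return
}
```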

2 files changed: +35, -2 lines


mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 3 additions & 2 deletions
@@ -914,8 +914,9 @@ struct NVGPUTmaAsyncLoadOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dest = rewriter.create<LLVM::ExtractValueOp>(op->getLoc(),
-                                                      adaptor.getDst(), 1);
+    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
+    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
+                                      adaptor.getDst(), {}, rewriter);
     Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
                                    op.getBarrier(), adaptor.getBarrier());
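
For reference, a sketch of the LLVM-dialect IR the new lowering is expected to emit for a destination with a static offset of 8192 f16 elements; this is reconstructed from the CHECK lines in the test below (value and function names are illustrative), not separately captured output. The old code used only the extractvalue of the aligned pointer, so the trailing GEP, and with it the offset, was missing.

```mlir
// Sketch (names illustrative): destination pointer computation produced by
// getStridedElementPtr for memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>.
llvm.func @lowered_dest_ptr_sketch(%desc: !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>) -> !llvm.ptr<3> {
  // Aligned pointer (field 1 of the memref descriptor); previously this was
  // used directly as the TMA destination, dropping the offset.
  %aligned = llvm.extractvalue %desc[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
  // The statically known offset is now folded in with a GEP over f16 elements.
  %c8192 = llvm.mlir.constant(8192 : index) : i64
  %dest = llvm.getelementptr %aligned[%c8192] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
  llvm.return %dest : !llvm.ptr<3>
}
```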

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 32 additions & 0 deletions
@@ -647,3 +647,35 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
   %tensorMap1d = nvgpu.tma.create.descriptor %devicePtr1d_unranked box[%crd1] : memref<*xf32> -> !tensorMap1d
   func.return
 }
+
+// -----
+
+!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+
+!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+
+!shmemlhs = memref<128x64xf16, 3>
+!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+
+module @mymodule {
+  // Dynamic shared memory
+  memref.global "private" @dynamicShmem : memref<0xf16, 3>
+
+  func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
+    %c0 = arith.constant 0 : index
+    %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to !shmemlhs
+    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<2x64x128xf16, 3>
+    %rhsShmem3 = memref.subview %rhsShmem2[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, 3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
+    %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+    // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
+    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+    // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
+    // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
+    // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+    return
+  }
+}
