+ // RUN: mlir-opt %s --convert-nvgpu-to-nvvm -gpu-kernel-outlining \
+ // RUN: -convert-scf-to-cf -convert-nvvm-to-llvm \
+ // RUN: -convert-vector-to-llvm \
+ // RUN: -convert-math-to-llvm \
+ // RUN: -expand-strided-metadata \
+ // RUN: -lower-affine \
+ // RUN: -convert-index-to-llvm=index-bitwidth=32 \
+ // RUN: -convert-arith-to-llvm \
+ // RUN: -finalize-memref-to-llvm \
+ // RUN: -convert-func-to-llvm \
+ // RUN: -canonicalize \
+ // RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \
+ // RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK-PTX
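+ // The pipeline above lowers the GPU/NVGPU dialects to NVVM and serializes the
+ // kernel to a cubin for sm_90 with +ptx80; dump-ptx emits the generated PTX
+ // so FileCheck can match the mbarrier and TMA instructions listed below.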
+
+ // CHECK-PTX: mbarrier.init.shared.b64
+ // CHECK-PTX: mbarrier.arrive.expect_tx.shared.b64
+ // CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+ // CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+ // CHECK-PTX: mbarrier.arrive.expect_tx.shared.b64
+ // CHECK-PTX: mbarrier.try_wait.parity.shared.b64
+
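+ // The host side initializes a 64x8 and an 8x128 f32 buffer, copies them to
+ // the device, and builds a TMA descriptor for each. The kernel then uses an
+ // mbarrier together with nvgpu.tma.async.load to bring both tiles into
+ // shared memory and prints a few elements before and after the transfer.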
+ module @mymod {
+   memref.global "private" @bufferLhsGlobal : memref<64x8xf32, 3>
+   memref.global "private" @bufferRhsGlobal : memref<8x128xf32, 3>
+   func.func @main() {
+     %c10000000 = arith.constant 10000000 : index
+     %c6144 = arith.constant 6144 : index
+     %c45 = arith.constant 45 : index
+     %c7 = arith.constant 7 : index
+     %c64 = arith.constant 64 : index
+     %c1 = arith.constant 1 : index
+     %c0 = arith.constant 0 : index
+     %c8 = arith.constant 8 : index
+     %c128 = arith.constant 128 : index
+     %cst = arith.constant 3.000000e+00 : f32
+     %alloc = memref.alloc() : memref<64x8xf32>
+     %alloc_0 = memref.alloc() : memref<8x128xf32>
+     scf.for %arg0 = %c0 to %c8 step %c1 {
+       scf.for %arg1 = %c0 to %c128 step %c1 {
+         memref.store %cst, %alloc_0[%arg0, %arg1] : memref<8x128xf32>
+       }
+     }
+     scf.for %arg0 = %c0 to %c64 step %c1 {
+       scf.for %arg1 = %c0 to %c8 step %c1 {
+         %5 = arith.index_cast %arg1 : index to i64
+         %6 = arith.uitofp %5 : i64 to f32
+         memref.store %6, %alloc[%arg0, %arg1] : memref<64x8xf32>
+       }
+     }
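+     // Allocate device buffers, copy the host data over, and create a TMA
+     // descriptor for each buffer (64x8 and 8x128 boxes, no swizzling).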
+     %0 = gpu.wait async
+     %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
+     %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+     %1 = gpu.memcpy async [%0] %memref, %alloc : memref<64x8xf32>, memref<64x8xf32>
+     %2 = gpu.memcpy async [%0] %memref_1, %alloc_0 : memref<8x128xf32>, memref<8x128xf32>
+     %cast = memref.cast %memref : memref<64x8xf32> to memref<*xf32>
+     %cast_3 = memref.cast %memref_1 : memref<8x128xf32> to memref<*xf32>
+     %3 = nvgpu.tma.create.descriptor %cast box[%c64, %c8] : memref<*xf32> -> <tensor = memref<64x8xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+     %4 = nvgpu.tma.create.descriptor %cast_3 box[%c8, %c128] : memref<*xf32> -> <tensor = memref<8x128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+     gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c128, %arg10 = %c1, %arg11 = %c1) {
+       %5 = gpu.block_dim x
+       %6 = gpu.thread_id x
+       %7 = memref.get_global @bufferLhsGlobal : memref<64x8xf32, 3>
+       %8 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, 3>
+       %9 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+       nvgpu.mbarrier.init %9, %5 : <memorySpace = #gpu.address_space<workgroup>>
+       gpu.barrier
+       %10 = arith.cmpi eq, %6, %c0 : index
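+       // Only thread 0 arms the mbarrier with the expected transaction size
+       // (64*8*4 + 8*128*4 = 6144 bytes) and issues the two TMA loads; all
+       // other threads arrive expecting zero bytes.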
+       scf.if %10 {
+         nvgpu.mbarrier.arrive.expect_tx %9, %c6144 : <memorySpace = #gpu.address_space<workgroup>>
+         %11 = memref.load %7[%c0, %c0] : memref<64x8xf32, 3>
+         %12 = memref.load %8[%c0, %c0] : memref<8x128xf32, 3>
+         gpu.printf "[GPU] TMA BEFORE lhs[45][7] %f\0A" %11 : f32
+         gpu.printf "[GPU] TMA BEFORE rhs[7][0] %f\0A" %12 : f32
+         nvgpu.tma.async.load %3[%c0, %c0], %9 to %7 : <tensor = memref<64x8xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x8xf32, 3>
+         nvgpu.tma.async.load %4[%c0, %c0], %9 to %8 : <tensor = memref<8x128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<8x128xf32, 3>
+       } else {
+         nvgpu.mbarrier.arrive.expect_tx %9, %c0 : <memorySpace = #gpu.address_space<workgroup>>
+       }
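+       // Every thread then waits on the mbarrier (phase parity 0) until the
+       // TMA transfers complete, after which thread 0 prints elements of the
+       // loaded tiles for verification.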
+       nvgpu.mbarrier.try_wait.parity %9, %c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
+       scf.if %10 {
+         %11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
+         %12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>
+         gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
+         gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
+       }
+       gpu.terminator
+     }
+     return
+   }
+ }