// |-------------------------------|

- !barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+ !barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
!tokenType = !nvgpu.mbarrier.token

!lhs = memref<128x64xf16>
@@ -96,56 +96,50 @@ module @mymod {
        memref.store %vL32, %lhs32[%j, %i] : memref<128x64xf32>
      }
    }
-
-    // Step 2. Print on the host
-    %lhs32_unranked = memref.cast %lhs32 : memref<128x64xf32> to memref<*xf32>
-    call @printMemrefF32(%lhs32_unranked) : (memref<*xf32>) -> ()
-    %rhs32_unranked = memref.cast %rhs32 : memref<64x128xf32> to memref<*xf32>
-    call @printMemrefF32(%rhs32_unranked) : (memref<*xf32>) -> ()

-    // Step 3. Copy host to device
+    // Step 2. Copy host to device
    %0 = gpu.wait async
    %d_glbmem_lhs, %asyncToken = gpu.alloc async [%0] () : !lhs
    %d_glbmem_rhs, %asyncToken_2 = gpu.alloc async [%0] () : !rhs
    %1 = gpu.memcpy async [%0] %d_glbmem_lhs, %lhs : !lhs, !lhs
    %2 = gpu.memcpy async [%0] %d_glbmem_rhs, %rhs : !rhs, !rhs

-    // Step 4. Create TMA tensor descriptor
+    // Step 3. Create TMA tensor descriptor
    %d_lhs_unranked = memref.cast %d_glbmem_lhs : !lhs to memref<*xf16>
    %d_rhs_unranked = memref.cast %d_glbmem_rhs : !rhs to memref<*xf16>

    %d_lhsTensorMap = nvgpu.tma.create.descriptor %d_lhs_unranked box[%c128, %c64] : memref<*xf16> -> !lhsTensorMap
    %d_rhsTensorMap = nvgpu.tma.create.descriptor %d_rhs_unranked box[%c64, %c64] : memref<*xf16> -> !rhsTensorMap

-    // Step 5. Launch a GPU kernel
+    // Step 4. Launch a GPU kernel
    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c128, %arg10 = %c1, %arg11 = %c1) {
      %5 = gpu.block_dim x
      %6 = gpu.thread_id x
      %lhsShmem = memref.get_global @bufferLhsGlobal : !shmemlhs
      %rhsShmem = memref.get_global @bufferRhsGlobal : !shmemrhs
      %rhsShmem2 = memref.subview %rhsShmem[%c32, %c0][%c32, %c128][%c1, %c1] : !shmemrhs to memref<?x?xf16, strided<[?, ?], offset: ?>, 3>

-      // Step 6. Initialize the mbarrier
+      // Step 5. Initialize the mbarrier
      %9 = nvgpu.mbarrier.create -> !barrierType
-      nvgpu.mbarrier.init %9, %5 : !barrierType
+      nvgpu.mbarrier.init %9[%c0], %5 : !barrierType
      %10 = arith.cmpi eq, %6, %c0 : index


-      // Step 7. First thread does TMA load
+      // Step 6. First thread does TMA load
      scf.if %10 {
        gpu.printf "[GPU] TMA SIZE %d\0A" %c32768 : index
-        nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9 to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
-        nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9 to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
-        nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9 to %rhsShmem2 : !rhsTensorMap, !barrierType -> memref<?x?xf16, strided<[?, ?], offset: ?>, 3>
-        nvgpu.mbarrier.arrive.expect_tx %9, %c32768 : !barrierType
+        nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+        nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+        nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9[%c0] to %rhsShmem2 : !rhsTensorMap, !barrierType -> memref<?x?xf16, strided<[?, ?], offset: ?>, 3>
+        nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c32768 : !barrierType
      } else {
-        nvgpu.mbarrier.arrive.expect_tx %9, %c0 : !barrierType
+        nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c0 : !barrierType
      }

-      // Step 8. Wait until TMA is done
-      nvgpu.mbarrier.try_wait.parity %9, %c0, %c10000000 : !barrierType
+      // Step 7. Wait until TMA is done
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : !barrierType

-      // Step 9. Print loaded data in 128b swizzled
+      // Step 8. Print loaded data in 128b swizzled
      scf.if %10 {
        gpu.printf "===--- Matrix B ---=== %d \n" %c-1_i32 : i32
        scf.for %ii = %c0 to %c64 step %c1 {
@@ -158,6 +152,7 @@ module @mymod {
        }
        gpu.printf "===----------------=== %d \n" %c-1_i32 : i32
      }
+      gpu.barrier
      gpu.terminator
    }
    return
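
For context, here is a minimal sketch (not part of this commit) of the indexed mbarrier-group pattern the kernel above now uses: the barrier lives in an `!nvgpu.mbarrier.group` and each op addresses a slot in that group, here slot `%c0`. The function wrapper, value names, and constants are illustrative only; in the actual test these ops run inside the `gpu.launch` body shown in the diff.

```mlir
// Illustrative sketch only; assumed standalone wrapper, mirroring the ops from the kernel above.
!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>

func.func @mbarrier_slot0_pattern(%threadCount : index) {
  %c0 = arith.constant 0 : index              // barrier slot index and wait parity
  %txBytes = arith.constant 32768 : index     // expected TMA transaction size in bytes
  %timeout = arith.constant 10000000 : index  // try_wait spin budget

  // Create a barrier group and initialize slot 0 for %threadCount arrivals.
  %bar = nvgpu.mbarrier.create -> !barrierType
  nvgpu.mbarrier.init %bar[%c0], %threadCount : !barrierType

  // The producing thread posts the byte count it expects TMA to deliver on slot 0 ...
  nvgpu.mbarrier.arrive.expect_tx %bar[%c0], %txBytes : !barrierType

  // ... and all threads spin on the same slot until the transfer completes.
  nvgpu.mbarrier.try_wait.parity %bar[%c0], %c0, %timeout : !barrierType
  return
}
```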