DEBUG = False


+class TmaDescriptorBuilder:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.tma_box_shape = tma_box_shape
+        self.memref_ty = memref_ty  # MemRefType
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        memref_str = f"memref<{self.tma_box_shape[0]}x{self.tma_box_shape[1]}x{self.memref_ty.element_type}, 3>"
+        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
+                      swizzle = {self.swizzle},\
+                      l2promo = {self.l2promo},\
+                      oob = {self.oob},\
+                      interleave = {self.interleave}>"
+        return ir.Type.parse(parse_str)
+
+    def tma_descriptor_op(self, device_ptr):
+        """Returns a tensormap descriptor op."""
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
+            device_ptr,
+        )
+        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
+        )
+        return tma_descriptor_op.result
+
+
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
    if not DEBUG and not forcePrint:
        return
@@ -162,28 +199,6 @@ def generate_matmul_ws(
        + str(num_stages)
        + ">"
    )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
    acc_ty = ir.Type.parse(
        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
        + str(BLOCK_M)
@@ -240,21 +255,26 @@ def generate_matmul_ws(
    t7 = gpu.wait(token_ty, [t6])

    # Step 2. Create TMA Descriptors
-    tma_specs = [
-        (a_device, a_tma_desc_ty, a_tma_shape),
-        (b_device, b_tma_desc_ty, b_tma_shape),
-    ]
-    tma_descs = []
-    for x_device, tensor_map_ty, tile_shape in tma_specs:
-        x_unranked = memref.cast(
-            ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-        )
-        tma_descs.append(
-            nvgpu.TmaCreateDescriptorOp(
-                tensor_map_ty, x_unranked, map(c, tile_shape)
-            ).result
-        )
-    a_tma_desc, b_tma_desc = tma_descs
+    a_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        a_tma_shape,
+        a_ty,
+    )
+
+    b_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        b_tma_shape,
+        b_ty,
+    )
+
+    a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+    b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

    # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
    cta_m = M // BLOCK_M
@@ -267,7 +287,7 @@ def generate_matmul_ws(
        [t7],
        *map(c, grid),
        *map(c, block),
-        dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+        dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
    )
    launch_op.body.blocks.append(*([T.index()] * 12))
    with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +335,8 @@ def generate_matmul_ws(
        gpu.barrier()

        # GPU Step 3. Prefetch TMA descriptors
-        nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
-        nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+        nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
+        nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)

        ns = num_stages if num_stages == 1 else num_stages - 1
        # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,15 +425,15 @@ def generate_matmul_ws(
            nvgpu.TmaAsyncLoadOp(
                a_tma_slice,
                mbarTMA,
-                a_tma_desc,
+                a_tma_desc_op,
                coordinates=[coord, dimX],
                mbarId=stage,
                predicate=producerPrimaryThread,
            )
            nvgpu.TmaAsyncLoadOp(
                b_tma_slice_1,
                mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                coordinates=[dimY, coord],
                mbarId=stage,
                predicate=producerPrimaryThread,
@@ -422,7 +442,7 @@ def generate_matmul_ws(
            nvgpu.TmaAsyncLoadOp(
                b_tma_slice_2,
                mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                coordinates=[dimY2, coord],
                mbarId=stage,
                predicate=producerPrimaryThread,
@@ -514,10 +534,10 @@ def generate_matmul_ws(
                predicate=consumerPrimaryThread,
            )
            da = nvgpu.WarpgroupGenerateDescriptorOp(
-                a_wgmma_ty, a_tile_slice, a_tma_desc
+                a_wgmma_ty, a_tile_slice, a_tma_desc_op
            )
            db = nvgpu.WarpgroupGenerateDescriptorOp(
-                b_wgmma_ty, b_tile_slice, b_tma_desc
+                b_wgmma_ty, b_tile_slice, b_tma_desc_op
            )

            # Step 6.3.3. MMA
@@ -679,28 +699,6 @@ def generate_matmul_multistage(
        + str(num_stages)
        + ">"
    )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
    acc_ty = ir.Type.parse(
        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
        + str(BLOCK_M)
@@ -767,21 +765,26 @@ def generate_matmul_multistage(
    t7 = gpu.wait(token_ty, [t6])

    # Step 2. Create TMA Descriptors
-    tma_specs = [
-        (a_device, a_tma_desc_ty, a_tma_shape),
-        (b_device, b_tma_desc_ty, b_tma_shape),
-    ]
-    tma_descs = []
-    for x_device, tensor_map_ty, tile_shape in tma_specs:
-        x_unranked = memref.cast(
-            ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-        )
-        tma_descs.append(
-            nvgpu.TmaCreateDescriptorOp(
-                tensor_map_ty, x_unranked, map(c, tile_shape)
-            ).result
-        )
-    a_tma_desc, b_tma_desc = tma_descs
+    a_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        a_tma_shape,
+        a_ty,
+    )
+
+    b_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        b_tma_shape,
+        b_ty,
+    )
+
+    a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+    b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

    # Step 3. Launch Kernel with 1 Warpgroup
    cta_m = M // BLOCK_M
@@ -794,7 +797,7 @@ def generate_matmul_multistage(
        [t7],
        *map(c, grid),
        *map(c, block),
-        dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+        dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
    )
    launch_op.body.blocks.append(*([T.index()] * 12))
    with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +822,8 @@ def generate_matmul_multistage(
        gpu.barrier()

        # GPU Step 2. Prefetch TMA descriptors
-        nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
-        nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+        nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
+        nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)

        # GPU Step 3. Prologue (global memory --> shared memory)
        ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,23 +883,23 @@ def generate_matmul_multistage(
            nvgpu.TmaAsyncLoadOp(
                a_tma_slice,
                mbarTMA,
-                a_tma_desc,
+                a_tma_desc_op,
                coordinates=[coord, dimX],
                mbarId=iv,
                predicate=primaryThread,
            )
            nvgpu.TmaAsyncLoadOp(
                b_tma_slice_1,
                mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                coordinates=[dimY, coord],
                mbarId=iv,
                predicate=primaryThread,
            )
            nvgpu.TmaAsyncLoadOp(
                b_tma_slice_2,
                mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                coordinates=[dimY2, coord],
                mbarId=iv,
                predicate=primaryThread,
@@ -972,10 +975,10 @@ def generate_matmul_multistage(
                predicate=primaryThread,
            )
            da = nvgpu.WarpgroupGenerateDescriptorOp(
-                a_wgmma_ty, a_tile_slice, a_tma_desc
+                a_wgmma_ty, a_tile_slice, a_tma_desc_op
            )
            db = nvgpu.WarpgroupGenerateDescriptorOp(
-                b_wgmma_ty, b_tile_slice, b_tma_desc
+                b_wgmma_ty, b_tile_slice, b_tma_desc_op
            )

            # Step 4.3. MMA
@@ -1060,15 +1063,15 @@ def generate_matmul_multistage(
            nvgpu.TmaAsyncLoadOp(
                a_tma_slice,
                mbarTMA,
-                a_tma_desc,
+                a_tma_desc_op,
                coordinates=[coord, dimX],
                mbarId=nextSlot,
                predicate=p,
            )
            nvgpu.TmaAsyncLoadOp(
                b_tma_slice_1,
                mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                coordinates=[dimY, coord],
                mbarId=nextSlot,
                predicate=p,
@@ -1077,7 +1080,7 @@ def generate_matmul_multistage(
            nvgpu.TmaAsyncLoadOp(
                b_tma_slice_2,
                mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                coordinates=[dimY2, coord],
                mbarId=nextSlot,
                predicate=p,
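For reference, a minimal sketch of how the new TmaDescriptorBuilder can be driven on its own. This is not part of the patch: it assumes the MLIR Python bindings are installed with the nvgpu dialect available, that the class is importable from the patched tools/matmulBuilder.py, and it only exercises the type-building half; tma_descriptor_op() additionally needs a live device pointer value inside a module, as in the generators above. The 128x64 f16 shape is illustrative only.

# Assumption: import path mirrors the test layout; adjust to where matmulBuilder.py lives.
from tools.matmulBuilder import TmaDescriptorBuilder
from mlir import ir
from mlir.dialects import nvgpu

with ir.Context():
    f16 = ir.F16Type.get()
    # Hypothetical 128x64 f16 operand; the same shape is reused as the TMA box shape.
    memref_ty = ir.MemRefType.get((128, 64), f16)
    desc = TmaDescriptorBuilder(
        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        nvgpu.TensorMapOOBKind.OOB_ZERO,
        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
        (128, 64),
        memref_ty,
    )
    # Prints the parsed !nvgpu.tensormap.descriptor<...> type built by the class.
    print(desc.tensormap_descriptor_ty)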