 DEBUG = False
 
 
+class TmaDescriptorBuilder:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.tma_box_shape = tma_box_shape
+        self.memref_ty = memref_ty  # MemRefType
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        tensorMemrefType = ir.MemRefType.get(
+            self.tma_box_shape,
+            self.memref_ty.element_type,
+            memory_space=ir.Attribute.parse("3"),
+        )
+        return nvgpu.TensorMapDescriptorType.get(
+            tensorMemrefType,
+            self.swizzle,
+            self.l2promo,
+            self.oob,
+            self.interleave,
+        )
+
+    def tma_descriptor_op(self, device_ptr):
+        """Returns a tensormap descriptor op."""
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
+            device_ptr,
+        )
+        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
+        )
+        return tma_descriptor_op.result
+
+
 def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
     if not DEBUG and not forcePrint:
         return
@@ -162,28 +204,6 @@ def generate_matmul_ws(
         + str(num_stages)
         + ">"
     )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
     acc_ty = ir.Type.parse(
         "!nvgpu.warpgroup.accumulator<fragmented=vector<"
         + str(BLOCK_M)
@@ -240,21 +260,26 @@ def generate_matmul_ws(
     t7 = gpu.wait(token_ty, [t6])
 
     # Step 2. Create TMA Descriptors
-    tma_specs = [
-        (a_device, a_tma_desc_ty, a_tma_shape),
-        (b_device, b_tma_desc_ty, b_tma_shape),
-    ]
-    tma_descs = []
-    for x_device, tensor_map_ty, tile_shape in tma_specs:
-        x_unranked = memref.cast(
-            ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-        )
-        tma_descs.append(
-            nvgpu.TmaCreateDescriptorOp(
-                tensor_map_ty, x_unranked, map(c, tile_shape)
-            ).result
-        )
-    a_tma_desc, b_tma_desc = tma_descs
+    a_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        a_tma_shape,
+        a_ty,
+    )
+
+    b_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        b_tma_shape,
+        b_ty,
+    )
+
+    a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+    b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
 
     # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
     cta_m = M // BLOCK_M
@@ -267,7 +292,7 @@ def generate_matmul_ws(
         [t7],
         *map(c, grid),
         *map(c, block),
-        dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+        dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
     )
     launch_op.body.blocks.append(*([T.index()] * 12))
     with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +340,8 @@ def generate_matmul_ws(
         gpu.barrier()
 
         # GPU Step 3. Prefetch TMA descriptors
-        nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
-        nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+        nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
+        nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
 
         ns = num_stages if num_stages == 1 else num_stages - 1
         # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,15 +430,15 @@ def generate_matmul_ws(
                 nvgpu.TmaAsyncLoadOp(
                     a_tma_slice,
                     mbarTMA,
-                    a_tma_desc,
+                    a_tma_desc_op,
                     coordinates=[coord, dimX],
                     mbarId=stage,
                     predicate=producerPrimaryThread,
                 )
                 nvgpu.TmaAsyncLoadOp(
                     b_tma_slice_1,
                     mbarTMA,
-                    b_tma_desc,
+                    b_tma_desc_op,
                     coordinates=[dimY, coord],
                     mbarId=stage,
                     predicate=producerPrimaryThread,
@@ -422,7 +447,7 @@ def generate_matmul_ws(
                 nvgpu.TmaAsyncLoadOp(
                     b_tma_slice_2,
                     mbarTMA,
-                    b_tma_desc,
+                    b_tma_desc_op,
                     coordinates=[dimY2, coord],
                     mbarId=stage,
                     predicate=producerPrimaryThread,
@@ -514,10 +539,10 @@ def generate_matmul_ws(
                     predicate=consumerPrimaryThread,
                 )
                 da = nvgpu.WarpgroupGenerateDescriptorOp(
-                    a_wgmma_ty, a_tile_slice, a_tma_desc
+                    a_wgmma_ty, a_tile_slice, a_tma_desc_op
                 )
                 db = nvgpu.WarpgroupGenerateDescriptorOp(
-                    b_wgmma_ty, b_tile_slice, b_tma_desc
+                    b_wgmma_ty, b_tile_slice, b_tma_desc_op
                 )
 
                 # Step 6.3.3. MMA
@@ -679,28 +704,6 @@ def generate_matmul_multistage(
         + str(num_stages)
         + ">"
     )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
     acc_ty = ir.Type.parse(
         "!nvgpu.warpgroup.accumulator<fragmented=vector<"
         + str(BLOCK_M)
@@ -767,21 +770,26 @@ def generate_matmul_multistage(
     t7 = gpu.wait(token_ty, [t6])
 
     # Step 2. Create TMA Descriptors
-    tma_specs = [
-        (a_device, a_tma_desc_ty, a_tma_shape),
-        (b_device, b_tma_desc_ty, b_tma_shape),
-    ]
-    tma_descs = []
-    for x_device, tensor_map_ty, tile_shape in tma_specs:
-        x_unranked = memref.cast(
-            ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-        )
-        tma_descs.append(
-            nvgpu.TmaCreateDescriptorOp(
-                tensor_map_ty, x_unranked, map(c, tile_shape)
-            ).result
-        )
-    a_tma_desc, b_tma_desc = tma_descs
+    a_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        a_tma_shape,
+        a_ty,
+    )
+
+    b_tma_desc = TmaDescriptorBuilder(
+        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        nvgpu.TensorMapOOBKind.OOB_ZERO,
+        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+        b_tma_shape,
+        b_ty,
+    )
+
+    a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+    b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
 
     # Step 3. Launch Kernel with 1 Warpgroup
     cta_m = M // BLOCK_M
@@ -794,7 +802,7 @@ def generate_matmul_multistage(
         [t7],
         *map(c, grid),
         *map(c, block),
-        dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+        dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
     )
     launch_op.body.blocks.append(*([T.index()] * 12))
     with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +827,8 @@ def generate_matmul_multistage(
         gpu.barrier()
 
         # GPU Step 2. Prefetch TMA descriptors
-        nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
-        nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+        nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
+        nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
 
         # GPU Step 3. Prologue (global memory --> shared memory)
         ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,23 +888,23 @@ def generate_matmul_multistage(
             nvgpu.TmaAsyncLoadOp(
                 a_tma_slice,
                 mbarTMA,
-                a_tma_desc,
+                a_tma_desc_op,
                 coordinates=[coord, dimX],
                 mbarId=iv,
                 predicate=primaryThread,
             )
             nvgpu.TmaAsyncLoadOp(
                 b_tma_slice_1,
                 mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                 coordinates=[dimY, coord],
                 mbarId=iv,
                 predicate=primaryThread,
             )
             nvgpu.TmaAsyncLoadOp(
                 b_tma_slice_2,
                 mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                 coordinates=[dimY2, coord],
                 mbarId=iv,
                 predicate=primaryThread,
@@ -972,10 +980,10 @@ def generate_matmul_multistage(
                 predicate=primaryThread,
             )
             da = nvgpu.WarpgroupGenerateDescriptorOp(
-                a_wgmma_ty, a_tile_slice, a_tma_desc
+                a_wgmma_ty, a_tile_slice, a_tma_desc_op
            )
             db = nvgpu.WarpgroupGenerateDescriptorOp(
-                b_wgmma_ty, b_tile_slice, b_tma_desc
+                b_wgmma_ty, b_tile_slice, b_tma_desc_op
             )
 
             # Step 4.3. MMA
@@ -1060,15 +1068,15 @@ def generate_matmul_multistage(
             nvgpu.TmaAsyncLoadOp(
                 a_tma_slice,
                 mbarTMA,
-                a_tma_desc,
+                a_tma_desc_op,
                 coordinates=[coord, dimX],
                 mbarId=nextSlot,
                 predicate=p,
             )
             nvgpu.TmaAsyncLoadOp(
                 b_tma_slice_1,
                 mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                 coordinates=[dimY, coord],
                 mbarId=nextSlot,
                 predicate=p,
@@ -1077,7 +1085,7 @@ def generate_matmul_multistage(
             nvgpu.TmaAsyncLoadOp(
                 b_tma_slice_2,
                 mbarTMA,
-                b_tma_desc,
+                b_tma_desc_op,
                 coordinates=[dimY2, coord],
                 mbarId=nextSlot,
                 predicate=p,
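
For context on how the new helper is meant to be driven outside the two generators above, a minimal usage sketch follows (not part of the patch). It assumes an active MLIR context and location, which matmulBuilder.py already establishes before building IR, and the 128x64 f16 box shape is purely illustrative, not a value taken from the change.

    # Hedged usage sketch -- assumes an ir.Context/ir.Location are active and
    # that TmaDescriptorBuilder from this patch is in scope. The tile shape and
    # function name are illustrative assumptions.
    from mlir import ir
    from mlir.dialects import nvgpu

    def example_descriptor_type():
        # A hypothetical 128x64 f16 tile staged through shared memory (space 3).
        memref_ty = ir.MemRefType.get((128, 64), ir.F16Type.get())
        builder = TmaDescriptorBuilder(
            nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
            nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
            nvgpu.TensorMapOOBKind.OOB_ZERO,
            nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
            (128, 64),  # tma_box_shape
            memref_ty,
        )
        # Builds the same type the old code obtained from ir.Type.parse(
        #   "!nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, "
        #   "swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
        return builder.tensormap_descriptor_ty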