Skip to content

Commit c82f45f

Browse files
grypp and Manish Gupta authored
[mlir][nvgpu] Simplify TMA IR generation (#87153)
This PR add `TmaDescriptorBuilder` - class simplifies TMA generation. - Makes the code ready to support various Tma configurations - removes strings and use the enums from `mlir.nvgpu.ENUMs`. - Example "swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none" to enums in `mlir.nvgpu` dialects. - Enums have string equivalent that are used during the IR writing and generation (see `TmaDescriptorBuilder::tensormap_descriptor_ty`). - Improves readability and abstracts out TMA descriptor builders in reusable component. --------- Co-authored-by: Manish Gupta <[email protected]>
1 parent 609ee9f commit c82f45f

File tree

1 file changed

+101
-93
lines changed

1 file changed

+101
-93
lines changed

mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py

Lines changed: 101 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,48 @@
2828
DEBUG = False
2929

3030

31+
class TmaDescriptorBuilder:
    """Builds TMA (Tensor Memory Accelerator) descriptors for nvgpu ops.

    Captures one TMA configuration (swizzle / L2 promotion / out-of-bounds
    handling / interleave, plus the box shape and source memref type) and
    emits the corresponding descriptor type and descriptor-creation op.
    """

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        # Configuration enums from the nvgpu dialect bindings.
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        # Shape of the TMA box and the type of the memref being described.
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        # The boxed tile lives in shared memory (address space 3).
        box_memref_ty = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        descriptor_ty = nvgpu.TensorMapDescriptorType.get(
            box_memref_ty,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )
        return descriptor_ty

    def tma_descriptor_op(self, device_ptr):
        """Returns a tensormap descriptor op."""
        # TmaCreateDescriptorOp expects an unranked memref of the source.
        unranked_ty = ir.UnrankedMemRefType.get(
            self.memref_ty.element_type, self.memref_ty.memory_space
        )
        unranked_memref = memref.CastOp(unranked_ty, device_ptr)
        # Box dimensions must be index-typed SSA values; `c` is the
        # file-level constant-building helper.
        box_dims = [c(dim) for dim in self.tma_box_shape]
        descriptor = nvgpu.TmaCreateDescriptorOp(
            self.tensormap_descriptor_ty, unranked_memref, box_dims
        )
        return descriptor.result
71+
72+
3173
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
3274
if not DEBUG and not forcePrint:
3375
return
@@ -162,28 +204,6 @@ def generate_matmul_ws(
162204
+ str(num_stages)
163205
+ ">"
164206
)
165-
a_tma_desc_ty = ir.Type.parse(
166-
"!nvgpu.tensormap.descriptor<tensor = memref<"
167-
+ str(BLOCK_M)
168-
+ "x"
169-
+ str(TMA_LAST_DIM_F16)
170-
+ "x"
171-
+ str(a_elem_ty)
172-
+ ", "
173-
+ str(smem_space)
174-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
175-
)
176-
b_tma_desc_ty = ir.Type.parse(
177-
"!nvgpu.tensormap.descriptor<tensor = memref<"
178-
+ str(BLOCK_K)
179-
+ "x"
180-
+ str(TMA_LAST_DIM_F16)
181-
+ "x"
182-
+ str(b_elem_ty)
183-
+ ", "
184-
+ str(smem_space)
185-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
186-
)
187207
acc_ty = ir.Type.parse(
188208
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
189209
+ str(BLOCK_M)
@@ -240,21 +260,26 @@ def generate_matmul_ws(
240260
t7 = gpu.wait(token_ty, [t6])
241261

242262
# Step 2. Create TMA Descriptors
243-
tma_specs = [
244-
(a_device, a_tma_desc_ty, a_tma_shape),
245-
(b_device, b_tma_desc_ty, b_tma_shape),
246-
]
247-
tma_descs = []
248-
for x_device, tensor_map_ty, tile_shape in tma_specs:
249-
x_unranked = memref.cast(
250-
ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
251-
)
252-
tma_descs.append(
253-
nvgpu.TmaCreateDescriptorOp(
254-
tensor_map_ty, x_unranked, map(c, tile_shape)
255-
).result
256-
)
257-
a_tma_desc, b_tma_desc = tma_descs
263+
a_tma_desc = TmaDescriptorBuilder(
264+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
265+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
266+
nvgpu.TensorMapOOBKind.OOB_ZERO,
267+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
268+
a_tma_shape,
269+
a_ty,
270+
)
271+
272+
b_tma_desc = TmaDescriptorBuilder(
273+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
274+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
275+
nvgpu.TensorMapOOBKind.OOB_ZERO,
276+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
277+
b_tma_shape,
278+
b_ty,
279+
)
280+
281+
a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
282+
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
258283

259284
# Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
260285
cta_m = M // BLOCK_M
@@ -267,7 +292,7 @@ def generate_matmul_ws(
267292
[t7],
268293
*map(c, grid),
269294
*map(c, block),
270-
dynamicSharedMemorySize=c(smem_size, ty=T.i32())
295+
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
271296
)
272297
launch_op.body.blocks.append(*([T.index()] * 12))
273298
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +340,8 @@ def generate_matmul_ws(
315340
gpu.barrier()
316341

317342
# GPU Step 3. Prefetch TMA descriptors
318-
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
319-
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
343+
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
344+
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
320345

321346
ns = num_stages if num_stages == 1 else num_stages - 1
322347
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,15 +430,15 @@ def generate_matmul_ws(
405430
nvgpu.TmaAsyncLoadOp(
406431
a_tma_slice,
407432
mbarTMA,
408-
a_tma_desc,
433+
a_tma_desc_op,
409434
coordinates=[coord, dimX],
410435
mbarId=stage,
411436
predicate=producerPrimaryThread,
412437
)
413438
nvgpu.TmaAsyncLoadOp(
414439
b_tma_slice_1,
415440
mbarTMA,
416-
b_tma_desc,
441+
b_tma_desc_op,
417442
coordinates=[dimY, coord],
418443
mbarId=stage,
419444
predicate=producerPrimaryThread,
@@ -422,7 +447,7 @@ def generate_matmul_ws(
422447
nvgpu.TmaAsyncLoadOp(
423448
b_tma_slice_2,
424449
mbarTMA,
425-
b_tma_desc,
450+
b_tma_desc_op,
426451
coordinates=[dimY2, coord],
427452
mbarId=stage,
428453
predicate=producerPrimaryThread,
@@ -514,10 +539,10 @@ def generate_matmul_ws(
514539
predicate=consumerPrimaryThread,
515540
)
516541
da = nvgpu.WarpgroupGenerateDescriptorOp(
517-
a_wgmma_ty, a_tile_slice, a_tma_desc
542+
a_wgmma_ty, a_tile_slice, a_tma_desc_op
518543
)
519544
db = nvgpu.WarpgroupGenerateDescriptorOp(
520-
b_wgmma_ty, b_tile_slice, b_tma_desc
545+
b_wgmma_ty, b_tile_slice, b_tma_desc_op
521546
)
522547

523548
# Step 6.3.3. MMA
@@ -679,28 +704,6 @@ def generate_matmul_multistage(
679704
+ str(num_stages)
680705
+ ">"
681706
)
682-
a_tma_desc_ty = ir.Type.parse(
683-
"!nvgpu.tensormap.descriptor<tensor = memref<"
684-
+ str(BLOCK_M)
685-
+ "x"
686-
+ str(TMA_LAST_DIM_F16)
687-
+ "x"
688-
+ str(a_elem_ty)
689-
+ ", "
690-
+ str(smem_space)
691-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
692-
)
693-
b_tma_desc_ty = ir.Type.parse(
694-
"!nvgpu.tensormap.descriptor<tensor = memref<"
695-
+ str(BLOCK_K)
696-
+ "x"
697-
+ str(TMA_LAST_DIM_F16)
698-
+ "x"
699-
+ str(b_elem_ty)
700-
+ ", "
701-
+ str(smem_space)
702-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
703-
)
704707
acc_ty = ir.Type.parse(
705708
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
706709
+ str(BLOCK_M)
@@ -767,21 +770,26 @@ def generate_matmul_multistage(
767770
t7 = gpu.wait(token_ty, [t6])
768771

769772
# Step 2. Create TMA Descriptors
770-
tma_specs = [
771-
(a_device, a_tma_desc_ty, a_tma_shape),
772-
(b_device, b_tma_desc_ty, b_tma_shape),
773-
]
774-
tma_descs = []
775-
for x_device, tensor_map_ty, tile_shape in tma_specs:
776-
x_unranked = memref.cast(
777-
ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
778-
)
779-
tma_descs.append(
780-
nvgpu.TmaCreateDescriptorOp(
781-
tensor_map_ty, x_unranked, map(c, tile_shape)
782-
).result
783-
)
784-
a_tma_desc, b_tma_desc = tma_descs
773+
a_tma_desc = TmaDescriptorBuilder(
774+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
775+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
776+
nvgpu.TensorMapOOBKind.OOB_ZERO,
777+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
778+
a_tma_shape,
779+
a_ty,
780+
)
781+
782+
b_tma_desc = TmaDescriptorBuilder(
783+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
784+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
785+
nvgpu.TensorMapOOBKind.OOB_ZERO,
786+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
787+
b_tma_shape,
788+
b_ty,
789+
)
790+
791+
a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
792+
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
785793

786794
# Step 3. Launch Kernel with 1 Warpgroup
787795
cta_m = M // BLOCK_M
@@ -794,7 +802,7 @@ def generate_matmul_multistage(
794802
[t7],
795803
*map(c, grid),
796804
*map(c, block),
797-
dynamicSharedMemorySize=c(smem_size, ty=T.i32())
805+
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
798806
)
799807
launch_op.body.blocks.append(*([T.index()] * 12))
800808
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +827,8 @@ def generate_matmul_multistage(
819827
gpu.barrier()
820828

821829
# GPU Step 2. Prefetch TMA descriptors
822-
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
823-
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
830+
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
831+
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
824832

825833
# GPU Step 3. Prologue (global memory --> shared memory)
826834
ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,23 +888,23 @@ def generate_matmul_multistage(
880888
nvgpu.TmaAsyncLoadOp(
881889
a_tma_slice,
882890
mbarTMA,
883-
a_tma_desc,
891+
a_tma_desc_op,
884892
coordinates=[coord, dimX],
885893
mbarId=iv,
886894
predicate=primaryThread,
887895
)
888896
nvgpu.TmaAsyncLoadOp(
889897
b_tma_slice_1,
890898
mbarTMA,
891-
b_tma_desc,
899+
b_tma_desc_op,
892900
coordinates=[dimY, coord],
893901
mbarId=iv,
894902
predicate=primaryThread,
895903
)
896904
nvgpu.TmaAsyncLoadOp(
897905
b_tma_slice_2,
898906
mbarTMA,
899-
b_tma_desc,
907+
b_tma_desc_op,
900908
coordinates=[dimY2, coord],
901909
mbarId=iv,
902910
predicate=primaryThread,
@@ -972,10 +980,10 @@ def generate_matmul_multistage(
972980
predicate=primaryThread,
973981
)
974982
da = nvgpu.WarpgroupGenerateDescriptorOp(
975-
a_wgmma_ty, a_tile_slice, a_tma_desc
983+
a_wgmma_ty, a_tile_slice, a_tma_desc_op
976984
)
977985
db = nvgpu.WarpgroupGenerateDescriptorOp(
978-
b_wgmma_ty, b_tile_slice, b_tma_desc
986+
b_wgmma_ty, b_tile_slice, b_tma_desc_op
979987
)
980988

981989
# Step 4.3. MMA
@@ -1060,15 +1068,15 @@ def generate_matmul_multistage(
10601068
nvgpu.TmaAsyncLoadOp(
10611069
a_tma_slice,
10621070
mbarTMA,
1063-
a_tma_desc,
1071+
a_tma_desc_op,
10641072
coordinates=[coord, dimX],
10651073
mbarId=nextSlot,
10661074
predicate=p,
10671075
)
10681076
nvgpu.TmaAsyncLoadOp(
10691077
b_tma_slice_1,
10701078
mbarTMA,
1071-
b_tma_desc,
1079+
b_tma_desc_op,
10721080
coordinates=[dimY, coord],
10731081
mbarId=nextSlot,
10741082
predicate=p,
@@ -1077,7 +1085,7 @@ def generate_matmul_multistage(
10771085
nvgpu.TmaAsyncLoadOp(
10781086
b_tma_slice_2,
10791087
mbarTMA,
1080-
b_tma_desc,
1088+
b_tma_desc_op,
10811089
coordinates=[dimY2, coord],
10821090
mbarId=nextSlot,
10831091
predicate=p,

0 commit comments

Comments
 (0)