Skip to content

Commit 52030c8

Browse files
grypp (Manish Gupta) authored and committed
[mlir][nvgpu] Simplify TMA IR generation
This PR simplifies TMA generation in the test and makes the code more readable. Co-authored-by: Manish Gupta <[email protected]>
1 parent d06ba37 commit 52030c8

File tree

1 file changed

+96
-93
lines changed

1 file changed

+96
-93
lines changed

mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py

Lines changed: 96 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,43 @@
2828
DEBUG = False
2929

3030

31+
class TmaDescriptorBuilder:
    """A class that builds a TMA descriptor.

    Collects the TMA (Tensor Memory Accelerator) configuration once, then
    produces the corresponding `!nvgpu.tensormap.descriptor` type and the
    `nvgpu.tma.create.descriptor` op on demand.
    """

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        """Stores the TMA configuration used to build the descriptor.

        Args:
            swizzle: swizzling mode (mlir.nvgpu.TensorMapSwizzleKind).
            l2promo: L2 promotion behavior (mlir.nvgpu.TensorMapL2PromoKind).
            oob: out-of-bounds fill policy (mlir.nvgpu.TensorMapOOBKind).
            interleave: interleave layout (mlir.nvgpu.TensorMapInterleaveKind).
            tma_box_shape: per-dimension sizes of the TMA box (tile) to copy.
            memref_ty: MemRefType of the source memref the descriptor maps.
        """
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type.

        NOTE(review): the tensor memref string hard-codes memory space 3 and
        reads only tma_box_shape[0] and tma_box_shape[1], so a 2-D box shape
        is assumed — confirm this holds for all callers.
        """
        memref_str = f"memref<{self.tma_box_shape[0]}x{self.tma_box_shape[1]}x{self.memref_ty.element_type}, 3>"
        # Backslash continuations are inside the f-string literal, so the
        # continuation-line indentation becomes part of the parsed text;
        # the MLIR type parser accepts the extra whitespace.
        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
          swizzle = {self.swizzle},\
          l2promo = {self.l2promo},\
          oob = {self.oob},\
          interleave = {self.interleave}>"
        return ir.Type.parse(parse_str)

    def tma_descriptor_op(self, device_ptr):
        """Returns a tensormap descriptor op.

        Args:
            device_ptr: device-side memref value to build the descriptor for.

        Returns:
            The result value of the created nvgpu.TmaCreateDescriptorOp.
        """
        tma_descriptor_ty = self.tensormap_descriptor_ty
        # TmaCreateDescriptorOp expects an unranked memref, so cast the
        # ranked device pointer first.
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        # Box dimensions are materialized as constants via `c` (helper
        # defined elsewhere in this file).
        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
        )
        return tma_descriptor_op.result
66+
67+
3168
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
3269
if not DEBUG and not forcePrint:
3370
return
@@ -162,28 +199,6 @@ def generate_matmul_ws(
162199
+ str(num_stages)
163200
+ ">"
164201
)
165-
a_tma_desc_ty = ir.Type.parse(
166-
"!nvgpu.tensormap.descriptor<tensor = memref<"
167-
+ str(BLOCK_M)
168-
+ "x"
169-
+ str(TMA_LAST_DIM_F16)
170-
+ "x"
171-
+ str(a_elem_ty)
172-
+ ", "
173-
+ str(smem_space)
174-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
175-
)
176-
b_tma_desc_ty = ir.Type.parse(
177-
"!nvgpu.tensormap.descriptor<tensor = memref<"
178-
+ str(BLOCK_K)
179-
+ "x"
180-
+ str(TMA_LAST_DIM_F16)
181-
+ "x"
182-
+ str(b_elem_ty)
183-
+ ", "
184-
+ str(smem_space)
185-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
186-
)
187202
acc_ty = ir.Type.parse(
188203
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
189204
+ str(BLOCK_M)
@@ -240,21 +255,26 @@ def generate_matmul_ws(
240255
t7 = gpu.wait(token_ty, [t6])
241256

242257
# Step 2. Create TMA Descriptors
243-
tma_specs = [
244-
(a_device, a_tma_desc_ty, a_tma_shape),
245-
(b_device, b_tma_desc_ty, b_tma_shape),
246-
]
247-
tma_descs = []
248-
for x_device, tensor_map_ty, tile_shape in tma_specs:
249-
x_unranked = memref.cast(
250-
ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
251-
)
252-
tma_descs.append(
253-
nvgpu.TmaCreateDescriptorOp(
254-
tensor_map_ty, x_unranked, map(c, tile_shape)
255-
).result
256-
)
257-
a_tma_desc, b_tma_desc = tma_descs
258+
a_tma_desc = TmaDescriptorBuilder(
259+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
260+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
261+
nvgpu.TensorMapOOBKind.OOB_ZERO,
262+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
263+
a_tma_shape,
264+
a_ty,
265+
)
266+
267+
b_tma_desc = TmaDescriptorBuilder(
268+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
269+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
270+
nvgpu.TensorMapOOBKind.OOB_ZERO,
271+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
272+
b_tma_shape,
273+
b_ty,
274+
)
275+
276+
a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
277+
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
258278

259279
# Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
260280
cta_m = M // BLOCK_M
@@ -267,7 +287,7 @@ def generate_matmul_ws(
267287
[t7],
268288
*map(c, grid),
269289
*map(c, block),
270-
dynamicSharedMemorySize=c(smem_size, ty=T.i32())
290+
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
271291
)
272292
launch_op.body.blocks.append(*([T.index()] * 12))
273293
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +335,8 @@ def generate_matmul_ws(
315335
gpu.barrier()
316336

317337
# GPU Step 3. Prefetch TMA descriptors
318-
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
319-
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
338+
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
339+
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
320340

321341
ns = num_stages if num_stages == 1 else num_stages - 1
322342
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,15 +425,15 @@ def generate_matmul_ws(
405425
nvgpu.TmaAsyncLoadOp(
406426
a_tma_slice,
407427
mbarTMA,
408-
a_tma_desc,
428+
a_tma_desc_op,
409429
coordinates=[coord, dimX],
410430
mbarId=stage,
411431
predicate=producerPrimaryThread,
412432
)
413433
nvgpu.TmaAsyncLoadOp(
414434
b_tma_slice_1,
415435
mbarTMA,
416-
b_tma_desc,
436+
b_tma_desc_op,
417437
coordinates=[dimY, coord],
418438
mbarId=stage,
419439
predicate=producerPrimaryThread,
@@ -422,7 +442,7 @@ def generate_matmul_ws(
422442
nvgpu.TmaAsyncLoadOp(
423443
b_tma_slice_2,
424444
mbarTMA,
425-
b_tma_desc,
445+
b_tma_desc_op,
426446
coordinates=[dimY2, coord],
427447
mbarId=stage,
428448
predicate=producerPrimaryThread,
@@ -514,10 +534,10 @@ def generate_matmul_ws(
514534
predicate=consumerPrimaryThread,
515535
)
516536
da = nvgpu.WarpgroupGenerateDescriptorOp(
517-
a_wgmma_ty, a_tile_slice, a_tma_desc
537+
a_wgmma_ty, a_tile_slice, a_tma_desc_op
518538
)
519539
db = nvgpu.WarpgroupGenerateDescriptorOp(
520-
b_wgmma_ty, b_tile_slice, b_tma_desc
540+
b_wgmma_ty, b_tile_slice, b_tma_desc_op
521541
)
522542

523543
# Step 6.3.3. MMA
@@ -679,28 +699,6 @@ def generate_matmul_multistage(
679699
+ str(num_stages)
680700
+ ">"
681701
)
682-
a_tma_desc_ty = ir.Type.parse(
683-
"!nvgpu.tensormap.descriptor<tensor = memref<"
684-
+ str(BLOCK_M)
685-
+ "x"
686-
+ str(TMA_LAST_DIM_F16)
687-
+ "x"
688-
+ str(a_elem_ty)
689-
+ ", "
690-
+ str(smem_space)
691-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
692-
)
693-
b_tma_desc_ty = ir.Type.parse(
694-
"!nvgpu.tensormap.descriptor<tensor = memref<"
695-
+ str(BLOCK_K)
696-
+ "x"
697-
+ str(TMA_LAST_DIM_F16)
698-
+ "x"
699-
+ str(b_elem_ty)
700-
+ ", "
701-
+ str(smem_space)
702-
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
703-
)
704702
acc_ty = ir.Type.parse(
705703
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
706704
+ str(BLOCK_M)
@@ -767,21 +765,26 @@ def generate_matmul_multistage(
767765
t7 = gpu.wait(token_ty, [t6])
768766

769767
# Step 2. Create TMA Descriptors
770-
tma_specs = [
771-
(a_device, a_tma_desc_ty, a_tma_shape),
772-
(b_device, b_tma_desc_ty, b_tma_shape),
773-
]
774-
tma_descs = []
775-
for x_device, tensor_map_ty, tile_shape in tma_specs:
776-
x_unranked = memref.cast(
777-
ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
778-
)
779-
tma_descs.append(
780-
nvgpu.TmaCreateDescriptorOp(
781-
tensor_map_ty, x_unranked, map(c, tile_shape)
782-
).result
783-
)
784-
a_tma_desc, b_tma_desc = tma_descs
768+
a_tma_desc = TmaDescriptorBuilder(
769+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
770+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
771+
nvgpu.TensorMapOOBKind.OOB_ZERO,
772+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
773+
a_tma_shape,
774+
a_ty,
775+
)
776+
777+
b_tma_desc = TmaDescriptorBuilder(
778+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
779+
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
780+
nvgpu.TensorMapOOBKind.OOB_ZERO,
781+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
782+
b_tma_shape,
783+
b_ty,
784+
)
785+
786+
a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
787+
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
785788

786789
# Step 3. Launch Kernel with 1 Warpgroup
787790
cta_m = M // BLOCK_M
@@ -794,7 +797,7 @@ def generate_matmul_multistage(
794797
[t7],
795798
*map(c, grid),
796799
*map(c, block),
797-
dynamicSharedMemorySize=c(smem_size, ty=T.i32())
800+
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
798801
)
799802
launch_op.body.blocks.append(*([T.index()] * 12))
800803
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +822,8 @@ def generate_matmul_multistage(
819822
gpu.barrier()
820823

821824
# GPU Step 2. Prefetch TMA descriptors
822-
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
823-
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
825+
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
826+
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
824827

825828
# GPU Step 3. Prologue (global memory --> shared memory)
826829
ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,23 +883,23 @@ def generate_matmul_multistage(
880883
nvgpu.TmaAsyncLoadOp(
881884
a_tma_slice,
882885
mbarTMA,
883-
a_tma_desc,
886+
a_tma_desc_op,
884887
coordinates=[coord, dimX],
885888
mbarId=iv,
886889
predicate=primaryThread,
887890
)
888891
nvgpu.TmaAsyncLoadOp(
889892
b_tma_slice_1,
890893
mbarTMA,
891-
b_tma_desc,
894+
b_tma_desc_op,
892895
coordinates=[dimY, coord],
893896
mbarId=iv,
894897
predicate=primaryThread,
895898
)
896899
nvgpu.TmaAsyncLoadOp(
897900
b_tma_slice_2,
898901
mbarTMA,
899-
b_tma_desc,
902+
b_tma_desc_op,
900903
coordinates=[dimY2, coord],
901904
mbarId=iv,
902905
predicate=primaryThread,
@@ -972,10 +975,10 @@ def generate_matmul_multistage(
972975
predicate=primaryThread,
973976
)
974977
da = nvgpu.WarpgroupGenerateDescriptorOp(
975-
a_wgmma_ty, a_tile_slice, a_tma_desc
978+
a_wgmma_ty, a_tile_slice, a_tma_desc_op
976979
)
977980
db = nvgpu.WarpgroupGenerateDescriptorOp(
978-
b_wgmma_ty, b_tile_slice, b_tma_desc
981+
b_wgmma_ty, b_tile_slice, b_tma_desc_op
979982
)
980983

981984
# Step 4.3. MMA
@@ -1060,15 +1063,15 @@ def generate_matmul_multistage(
10601063
nvgpu.TmaAsyncLoadOp(
10611064
a_tma_slice,
10621065
mbarTMA,
1063-
a_tma_desc,
1066+
a_tma_desc_op,
10641067
coordinates=[coord, dimX],
10651068
mbarId=nextSlot,
10661069
predicate=p,
10671070
)
10681071
nvgpu.TmaAsyncLoadOp(
10691072
b_tma_slice_1,
10701073
mbarTMA,
1071-
b_tma_desc,
1074+
b_tma_desc_op,
10721075
coordinates=[dimY, coord],
10731076
mbarId=nextSlot,
10741077
predicate=p,
@@ -1077,7 +1080,7 @@ def generate_matmul_multistage(
10771080
nvgpu.TmaAsyncLoadOp(
10781081
b_tma_slice_2,
10791082
mbarTMA,
1080-
b_tma_desc,
1083+
b_tma_desc_op,
10811084
coordinates=[dimY2, coord],
10821085
mbarId=nextSlot,
10831086
predicate=p,

0 commit comments

Comments
 (0)