[mlir][nvgpu] Simplify TMA IR generation #87153

Merged 2 commits on Apr 18, 2024
194 changes: 101 additions & 93 deletions mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -28,6 +28,48 @@
DEBUG = False


class TmaDescriptorBuilder:
"""A class that builds a TMA descriptor."""

def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
self.swizzle = swizzle # mlir.nvgpu.TensorMapSwizzleKind
self.l2promo = l2promo # mlir.nvgpu.TensorMapL2PromoKind
self.oob = oob # mlir.nvgpu.TensorMapOOBKind
self.interleave = interleave # mlir.nvgpu.TensorMapInterleaveKind
self.tma_box_shape = tma_box_shape
self.memref_ty = memref_ty # MemRefType

@property
def tensormap_descriptor_ty(self):
"""Returns a tensormap descriptor type."""
tensorMemrefType = ir.MemRefType.get(
self.tma_box_shape,
self.memref_ty.element_type,
memory_space=ir.Attribute.parse("3"),
)
return nvgpu.TensorMapDescriptorType.get(
tensorMemrefType,
self.swizzle,
self.l2promo,
self.oob,
self.interleave,
)

def tma_descriptor_op(self, device_ptr):
"""Returns a tensormap descriptor op."""
tma_descriptor_ty = self.tensormap_descriptor_ty
device_unranked_memref = memref.CastOp(
ir.UnrankedMemRefType.get(
self.memref_ty.element_type, self.memref_ty.memory_space
),
device_ptr,
)
tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
)
return tma_descriptor_op.result
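# For reference, tensormap_descriptor_ty builds the same
# !nvgpu.tensormap.descriptor type that the deleted string-concatenation
# code below assembled by hand. A minimal sketch of the textual form,
# assuming a 128x64 f16 TMA box in shared memory (address space 3):
#   !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>,
#     swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>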


def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
if not DEBUG and not forcePrint:
return
@@ -162,28 +204,6 @@ def generate_matmul_ws(
+ str(num_stages)
+ ">"
)
a_tma_desc_ty = ir.Type.parse(
"!nvgpu.tensormap.descriptor<tensor = memref<"
+ str(BLOCK_M)
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
+ str(a_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
)
b_tma_desc_ty = ir.Type.parse(
"!nvgpu.tensormap.descriptor<tensor = memref<"
+ str(BLOCK_K)
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
+ str(b_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
)
acc_ty = ir.Type.parse(
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ str(BLOCK_M)
@@ -240,21 +260,26 @@ def generate_matmul_ws(
t7 = gpu.wait(token_ty, [t6])

# Step 2. Create TMA Descriptors
tma_specs = [
(a_device, a_tma_desc_ty, a_tma_shape),
(b_device, b_tma_desc_ty, b_tma_shape),
]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
x_unranked = memref.cast(
ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
)
tma_descs.append(
nvgpu.TmaCreateDescriptorOp(
tensor_map_ty, x_unranked, map(c, tile_shape)
).result
)
a_tma_desc, b_tma_desc = tma_descs
a_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
a_tma_shape,
a_ty,
)

b_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
b_tma_shape,
b_ty,
)

a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
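            # Each call above casts the device buffer to an unranked memref
            # and emits nvgpu.TmaCreateDescriptorOp with the box shape
            # materialized via map(c, ...), mirroring the per-operand casts
            # that the deleted loop performed.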

# Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
cta_m = M // BLOCK_M
@@ -267,7 +292,7 @@
[t7],
*map(c, grid),
*map(c, block),
dynamicSharedMemorySize=c(smem_size, ty=T.i32())
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
)
launch_op.body.blocks.append(*([T.index()] * 12))
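            # The gpu.launch region carries 12 leading index arguments:
            # block ids, thread ids, grid dims, and block dims (x, y, z each).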
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +340,8 @@ def generate_matmul_ws(
gpu.barrier()

# GPU Step 3. Prefetch TMA descriptors
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
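                # Prefetching here warms the descriptors in cache before the
                # TmaAsyncLoadOp calls below consume them.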

ns = num_stages if num_stages == 1 else num_stages - 1
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,15 +430,15 @@ def generate_matmul_ws(
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
a_tma_desc,
a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=stage,
predicate=producerPrimaryThread,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
b_tma_desc,
b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=stage,
predicate=producerPrimaryThread,
@@ -422,7 +447,7 @@
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
b_tma_desc,
b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=stage,
predicate=producerPrimaryThread,
@@ -514,10 +539,10 @@ def generate_matmul_ws(
predicate=consumerPrimaryThread,
)
da = nvgpu.WarpgroupGenerateDescriptorOp(
a_wgmma_ty, a_tile_slice, a_tma_desc
a_wgmma_ty, a_tile_slice, a_tma_desc_op
)
db = nvgpu.WarpgroupGenerateDescriptorOp(
b_wgmma_ty, b_tile_slice, b_tma_desc
b_wgmma_ty, b_tile_slice, b_tma_desc_op
)

# Step 6.3.3. MMA
@@ -679,28 +704,6 @@ def generate_matmul_multistage(
+ str(num_stages)
+ ">"
)
a_tma_desc_ty = ir.Type.parse(
"!nvgpu.tensormap.descriptor<tensor = memref<"
+ str(BLOCK_M)
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
+ str(a_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
)
b_tma_desc_ty = ir.Type.parse(
"!nvgpu.tensormap.descriptor<tensor = memref<"
+ str(BLOCK_K)
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
+ str(b_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
)
acc_ty = ir.Type.parse(
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ str(BLOCK_M)
@@ -767,21 +770,26 @@ def generate_matmul_multistage(
t7 = gpu.wait(token_ty, [t6])

# Step 2. Create TMA Descriptors
tma_specs = [
(a_device, a_tma_desc_ty, a_tma_shape),
(b_device, b_tma_desc_ty, b_tma_shape),
]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
x_unranked = memref.cast(
ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
)
tma_descs.append(
nvgpu.TmaCreateDescriptorOp(
tensor_map_ty, x_unranked, map(c, tile_shape)
).result
)
a_tma_desc, b_tma_desc = tma_descs
a_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
a_tma_shape,
a_ty,
)

b_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
b_tma_shape,
b_ty,
)

a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

# Step 3. Launch Kernel with 1 Warpgroup
cta_m = M // BLOCK_M
@@ -794,7 +802,7 @@
[t7],
*map(c, grid),
*map(c, block),
dynamicSharedMemorySize=c(smem_size, ty=T.i32())
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
)
launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +827,8 @@
gpu.barrier()

# GPU Step 2. Prefetch TMA descriptors
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)

# GPU Step 3. Prologue (global memory --> shared memory)
ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,23 +888,23 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
a_tma_desc,
a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=iv,
predicate=primaryThread,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
b_tma_desc,
b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=iv,
predicate=primaryThread,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
b_tma_desc,
b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=iv,
predicate=primaryThread,
@@ -972,10 +980,10 @@ def generate_matmul_multistage(
predicate=primaryThread,
)
da = nvgpu.WarpgroupGenerateDescriptorOp(
a_wgmma_ty, a_tile_slice, a_tma_desc
a_wgmma_ty, a_tile_slice, a_tma_desc_op
)
db = nvgpu.WarpgroupGenerateDescriptorOp(
b_wgmma_ty, b_tile_slice, b_tma_desc
b_wgmma_ty, b_tile_slice, b_tma_desc_op
)

# Step 4.3. MMA
@@ -1060,15 +1068,15 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
a_tma_desc,
a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=nextSlot,
predicate=p,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
b_tma_desc,
b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=nextSlot,
predicate=p,
@@ -1077,7 +1085,7 @@
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
b_tma_desc,
b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=nextSlot,
predicate=p,