[mlir][nvgpu] Simplify TMA IR generation #87153


Merged
merged 2 commits into from
Apr 18, 2024

Conversation

grypp
Member

@grypp grypp commented Mar 30, 2024

This PR adds a TmaDescriptorBuilder class that:

  • simplifies TMA descriptor generation;
  • makes the code ready to support various TMA configurations:
    • replaces raw strings such as "swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none" with the enums from the mlir.nvgpu dialect;
    • the enums have string equivalents that are used during IR generation (see TmaDescriptorBuilder.tensormap_descriptor_ty);
  • improves readability and abstracts the TMA descriptor builder into a reusable component.
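For illustration, here is a minimal standalone sketch of the pattern the builder uses: enum members carry the IR spelling that gets spliced into the descriptor type string before it is handed to `ir.Type.parse`. The `Swizzle`/`L2Promo`/`OOB`/`Interleave` enums below are stand-ins mimicking the real `mlir.nvgpu` enums, not the actual MLIR bindings:

```python
from enum import Enum


class StrEnum(Enum):
    """Enum whose str() is its IR spelling, mimicking the mlir.nvgpu enums."""

    def __str__(self):
        return self.value


# Stand-ins for nvgpu.TensorMapSwizzleKind, TensorMapL2PromoKind, etc.
class Swizzle(StrEnum):
    SWIZZLE_128B = "swizzle_128b"


class L2Promo(StrEnum):
    NONE = "none"


class OOB(StrEnum):
    ZERO = "zero"


class Interleave(StrEnum):
    NONE = "none"


def tensormap_descriptor_str(box_shape, elem_ty, swizzle, l2promo, oob, interleave):
    """Assemble the !nvgpu.tensormap.descriptor type string, as the builder's
    tensormap_descriptor_ty property does before calling ir.Type.parse."""
    # Memory space 3 is the shared-memory address space used in the test.
    memref_str = f"memref<{box_shape[0]}x{box_shape[1]}x{elem_ty}, 3>"
    return (
        f"!nvgpu.tensormap.descriptor<tensor = {memref_str}, "
        f"swizzle = {swizzle}, l2promo = {l2promo}, "
        f"oob = {oob}, interleave = {interleave}>"
    )


print(
    tensormap_descriptor_str(
        (128, 64), "f16",
        Swizzle.SWIZZLE_128B, L2Promo.NONE, OOB.ZERO, Interleave.NONE,
    )
)
```

Because the enums stringify to their IR spellings, the builder needs no hand-written string tables, and call sites select configurations by enum value rather than by editing format strings.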

@llvmbot
Member

llvmbot commented Mar 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes

This PR adds a TmaDescriptorBuilder class that simplifies TMA generation in the test and makes the code more readable.


Full diff: https://github.com/llvm/llvm-project/pull/87153.diff

1 file affected:

  • (modified) mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py (+96-93)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index fac138dce605a7..6823587801a7b0 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -28,6 +28,43 @@
 DEBUG = False
 
 
+class TmaDescriptorBuilder:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.tma_box_shape = tma_box_shape
+        self.memref_ty = memref_ty  # MemRefType
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        memref_str = f"memref<{self.tma_box_shape[0]}x{self.tma_box_shape[1]}x{self.memref_ty.element_type}, 3>"
+        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
+                                              swizzle = {self.swizzle},\
+                                              l2promo = {self.l2promo},\
+                                              oob = {self.oob},\
+                                              interleave = {self.interleave}>"
+        return ir.Type.parse(parse_str)
+
+    def tma_descriptor_op(self, device_ptr):
+        """Returns a tensormap descriptor op."""
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
+            device_ptr,
+        )
+        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
+        )
+        return tma_descriptor_op.result
+
+
 def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
     if not DEBUG and not forcePrint:
         return
@@ -162,28 +199,6 @@ def generate_matmul_ws(
         + str(num_stages)
         + ">"
     )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
     acc_ty = ir.Type.parse(
         "!nvgpu.warpgroup.accumulator<fragmented=vector<"
         + str(BLOCK_M)
@@ -240,21 +255,26 @@ def generate_matmul_ws(
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [
-                (a_device, a_tma_desc_ty, a_tma_shape),
-                (b_device, b_tma_desc_ty, b_tma_shape),
-            ]
-            tma_descs = []
-            for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-                )
-                tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(
-                        tensor_map_ty, x_unranked, map(c, tile_shape)
-                    ).result
-                )
-            a_tma_desc, b_tma_desc = tma_descs
+            a_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                a_tma_shape,
+                a_ty,
+            )
+
+            b_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                b_tma_shape,
+                b_ty,
+            )
+
+            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
 
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
             cta_m = M // BLOCK_M
@@ -267,7 +287,7 @@ def generate_matmul_ws(
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
             )
             launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +335,8 @@ def generate_matmul_ws(
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
 
                 ns = num_stages if num_stages == 1 else num_stages - 1
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,7 +425,7 @@ def generate_matmul_ws(
                         nvgpu.TmaAsyncLoadOp(
                             a_tma_slice,
                             mbarTMA,
-                            a_tma_desc,
+                            a_tma_desc_op,
                             coordinates=[coord, dimX],
                             mbarId=stage,
                             predicate=producerPrimaryThread,
@@ -413,7 +433,7 @@ def generate_matmul_ws(
                         nvgpu.TmaAsyncLoadOp(
                             b_tma_slice_1,
                             mbarTMA,
-                            b_tma_desc,
+                            b_tma_desc_op,
                             coordinates=[dimY, coord],
                             mbarId=stage,
                             predicate=producerPrimaryThread,
@@ -422,7 +442,7 @@ def generate_matmul_ws(
                         nvgpu.TmaAsyncLoadOp(
                             b_tma_slice_2,
                             mbarTMA,
-                            b_tma_desc,
+                            b_tma_desc_op,
                             coordinates=[dimY2, coord],
                             mbarId=stage,
                             predicate=producerPrimaryThread,
@@ -514,10 +534,10 @@ def generate_matmul_ws(
                             predicate=consumerPrimaryThread,
                         )
                         da = nvgpu.WarpgroupGenerateDescriptorOp(
-                            a_wgmma_ty, a_tile_slice, a_tma_desc
+                            a_wgmma_ty, a_tile_slice, a_tma_desc_op
                         )
                         db = nvgpu.WarpgroupGenerateDescriptorOp(
-                            b_wgmma_ty, b_tile_slice, b_tma_desc
+                            b_wgmma_ty, b_tile_slice, b_tma_desc_op
                         )
 
                         # Step 6.3.3. MMA
@@ -679,28 +699,6 @@ def generate_matmul_multistage(
         + str(num_stages)
         + ">"
     )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
     acc_ty = ir.Type.parse(
         "!nvgpu.warpgroup.accumulator<fragmented=vector<"
         + str(BLOCK_M)
@@ -767,21 +765,26 @@ def generate_matmul_multistage(
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [
-                (a_device, a_tma_desc_ty, a_tma_shape),
-                (b_device, b_tma_desc_ty, b_tma_shape),
-            ]
-            tma_descs = []
-            for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-                )
-                tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(
-                        tensor_map_ty, x_unranked, map(c, tile_shape)
-                    ).result
-                )
-            a_tma_desc, b_tma_desc = tma_descs
+            a_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                a_tma_shape,
+                a_ty,
+            )
+
+            b_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                b_tma_shape,
+                b_ty,
+            )
+
+            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
 
             # Step 3. Launch Kernel with 1 Warpgroup
             cta_m = M // BLOCK_M
@@ -794,7 +797,7 @@ def generate_matmul_multistage(
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
             )
             launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +822,8 @@ def generate_matmul_multistage(
                 gpu.barrier()
 
                 # GPU Step 2. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
                 ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,7 +883,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         a_tma_slice,
                         mbarTMA,
-                        a_tma_desc,
+                        a_tma_desc_op,
                         coordinates=[coord, dimX],
                         mbarId=iv,
                         predicate=primaryThread,
@@ -888,7 +891,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_1,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY, coord],
                         mbarId=iv,
                         predicate=primaryThread,
@@ -896,7 +899,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_2,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY2, coord],
                         mbarId=iv,
                         predicate=primaryThread,
@@ -972,10 +975,10 @@ def generate_matmul_multistage(
                         predicate=primaryThread,
                     )
                     da = nvgpu.WarpgroupGenerateDescriptorOp(
-                        a_wgmma_ty, a_tile_slice, a_tma_desc
+                        a_wgmma_ty, a_tile_slice, a_tma_desc_op
                     )
                     db = nvgpu.WarpgroupGenerateDescriptorOp(
-                        b_wgmma_ty, b_tile_slice, b_tma_desc
+                        b_wgmma_ty, b_tile_slice, b_tma_desc_op
                     )
 
                     # Step 4.3. MMA
@@ -1060,7 +1063,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         a_tma_slice,
                         mbarTMA,
-                        a_tma_desc,
+                        a_tma_desc_op,
                         coordinates=[coord, dimX],
                         mbarId=nextSlot,
                         predicate=p,
@@ -1068,7 +1071,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_1,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY, coord],
                         mbarId=nextSlot,
                         predicate=p,
@@ -1077,7 +1080,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_2,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY2, coord],
                         mbarId=nextSlot,
                         predicate=p,

@grypp grypp marked this pull request as draft March 30, 2024 12:16

@manishucsd manishucsd left a comment

LGTM.

@manishucsd manishucsd marked this pull request as ready for review March 30, 2024 15:25
grypp added a commit that referenced this pull request Apr 17, 2024
This PR adds NVGPU dialects' TensorMapDescriptorType in the py bindings.

This is a follow-up issue from [this PR](#87153 (comment))
grypp and others added 2 commits April 18, 2024 07:25
This PR simplifies TMA generation in the test, makes the code more readable.

Co-authored-by: Manish Gupta <[email protected]>
@grypp grypp merged commit c82f45f into llvm:main Apr 18, 2024
5 participants