[mlir] GEMM Hopper Tensor Core Integration Test #81478
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

This test aims to validate the correctness of the supported GEMM kernels in NVGPU dialects, with current support for Multistage and Warp Specialization kernels. Example:
Parallelism Across CTAs

GEMM includes three loops defining the shape of the GEMM, specified in the `matmul` function. The program builds IR using the following loop structure, tiling the loops with the given tile size and parallelizing the two outermost loops into the first and second dimensions of CTAs.
Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread (pthread) requests load from TMA. Threads collectively wait for the data and perform mma operations. After completing the shape, threads together store first fragmented registers to shared memory, then from shared memory to global memory; this part is called the epilogue.

Execution Timeline of Multistage Kernel with 3 stages:
Warp Specialization Kernel

This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one as `producer warp group` and another as `consumer warp group`. The `producer warp group` is responsible for requesting TMA load, while the `consumer warp group` performs the mma operation. The epilogue section is handled by the `consumer warp group` as its threads own the fragmented registers.

Execution Timeline of Warp Specialization Kernel with 2 stages:
Patch is 49.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81478.diff 5 Files Affected:
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+ config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..e153fcb44b9860
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,186 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test aims to validate the correctness of the supported GEMM kernels in
+# NVGPU dialects, with current support for Multistage and Warp Specialization
+# kernels.
+# The test constructs and metaprograms IR using Python bindings, allowing
+# generic IR building. This flexibility enables changes to the shape,
+# tile size, or data type of the GEMM for testing purposes.
+# The entry function is `matmul`, where one can specify GEMM shape, tile size,
+# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
+# number of stages.
+# Verification is done via numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16, # input types
+# output_type=np.float32, # output type
+# M=4096, N=4096, K=4096, # Shape
+# BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+# use_warp_specialization=True, # Enable Warp Specialization
+# max_num_stages=3) # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs ---===
+#
+# GEMM includes three loops defining the shape of the GEMM, specified in the
+# `matmul` function.
+# The program builds IR using the following loop structure, tiling the loops
+# with the given tile size and parallelizing the two outermost loops into the
+# first and second dimensions of CTAs.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M) # parallelize across blockIdx.x
+# for(bj = 0; bj < N; bj += BLOCK_N) # parallelize across blockIdx.y
+# for(bk = 0; bk < K; bk += BLOCK_K)
+# for(i = bi; i < (bi + BLOCK_M); ++i)
+# for(j = bj; j < (bj + BLOCK_N); ++j)
+# for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) requests load from TMA. Threads collectively wait for the data and
+# perform mma operations. After completing the shape, threads together store
+# first fragmented registers to shared memory, then from shared memory to global
+# memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# | |Prologue ----> |MainLoop ----> |Epilogue |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2] |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........ |[wait-0][mma] |[wait-1][mma] |[wait-2][mma] | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as `producer warp group` and another as `consumer warp group`. The
+# `producer warp group` is responsible for requesting TMA load, while the
+# `consumer warp group` performs the mma operation. The epilogue section is
+# handled by the `consumer warp group` as its threads own the fragmented registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# | |MainLoop ----> | 1st Epilogue | 2nd Epilogue |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ........... | [shmem->global] |
+# |wgroup1 | .......| | | | | | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+import ctypes
+from mlir import runtime as rt
+
+def generate_matmul(input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3):
+ with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+ if use_warp_specilization:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
+ BLOCK_K, max_num_stages)
+ else:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
+ BLOCK_N, BLOCK_K, max_num_stages)
+
+ mlir_nvgpu_module.operation.verify()
+
+ # Save generated IR
+ if saveIR:
+ # print(mlir_nvgpu_module)
+ original_stdout = sys.stdout
+ with open('gemm.mlir', 'w') as f:
+ sys.stdout = f
+ print(mlir_nvgpu_module)
+ sys.stdout = original_stdout
+
+ # Get compiler
+ options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+ support_lib = os.getenv("SUPPORT_LIB")
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+ compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+
+ # Compile
+ engine = compiler.compile_and_jit(mlir_nvgpu_module)
+ return engine
+
+
+def matmul(input_type=np.float16,
+ output_type=np.float32,
+ M=128,
+ N=128,
+ K=128,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+ print_results=False,
+ no_verify=False):
+ # Print the configuration
+ ity = "f16" if input_type == np.float16 else "f32"
+ oty = "f16" if output_type == np.float16 else "f32"
+ gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
+ print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
+ "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
+ str(max_num_stages) + " --===")
+
+ # Build IR and compile
+ engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
+ saveIR, max_num_stages)
+
+ # Allocate matrices and invoke the matmul
+ c = np.zeros((M, N), output_type)
+ a = np.random.randn(M, K).astype(input_type)
+ b = np.random.randn(K, N).astype(input_type)
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+ mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+ kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+
+ # Launch the MLIR generated kernel
+ engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+ float_formatter = "{:.2f}".format
+ np.set_printoptions(formatter={'float_kind': float_formatter})
+
+ if print_results:
+ print(c)
+
+ # Verify the results
+ if not no_verify:
+ ref = a.astype(input_type) @ b.astype(input_type)
+ if print_results:
+ print(ref)
+ np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+ print("PASS ")
+
+
+# GEMM Multistage f32 += f16 * f16
+matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+# GEMM Warp Specialized f32 += f16 * f16
+matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..09ab35d4c5f15e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,676 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+
+TMA_LAST_DIM_F16 = 64 # 128B float16
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+ if not DEBUG and not forcePrint:
+ return
+ type_formats = []
+ for arg in args:
+ ty_format = None
+ if ir.IndexType.isinstance(arg.type):
+ ty_format = "%llu"
+ if ir.IntegerType.isinstance(arg.type):
+ width = ir.IntegerType(arg.type).width
+ if width == 64:
+ ty_format = "%llu"
+ elif width == 32:
+ ty_format = "%d"
+ elif width == 1:
+ ty_format = "%i"
+ if ir.F32Type.isinstance(arg.type):
+ ty_format = "%f"
+ if ty_format is None:
+ raise NotImplementedError(arg.type)
+ type_formats.append(ty_format)
+ if threadNumber != -1:
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+ scf.yield_([])
+ if_op = scf.IfOp(predicate)
+ with ir.InsertionPoint(if_op.then_block):
+ gpu.printf(fmt.format(*type_formats) + "\n", args)
+ scf.yield_([])
+
+
+def c(value, ty=None):
+ ty = ir.IndexType.get() if ty is None else ty
+ return arith.constant(ty, value)
+
+
+def generate_matmul_ws(input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ max_num_stages=3):
+ # Limitations for now
+ assert input_type == np.float16
+ assert output_type == np.float32
+ assert M % BLOCK_M == 0
+ assert N % BLOCK_N == 0
+ assert K % BLOCK_K == 0
+
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+ num_stages = min(required_stages, max_num_stages)
+
+ module = ir.Module.create()
+ f16 = ir.F16Type.get()
+ f32 = ir.F32Type.get()
+ i1 = ir.IntegerType.get_signless(1)
+ i32 = ir.IntegerType.get_signless(32)
+ index = ir.IndexType.get()
+ i8 = ir.IntegerType.get_signless(8)
+ token_ty = ir.Type.parse("!gpu.async.token")
+ a_ty = ir.MemRefType.get([M, K], f16)
+ b_ty = ir.MemRefType.get((K, N), f16)
+ c_elem_ty = f16 if output_type == np.float16 else f32
+ c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+ a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+ b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+ b_tile_shape = (BLOCK_K, BLOCK_N)
+ txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ smem_space_str = "#gpu.address_space<workgroup>"
+ smem_space = ir.Attribute.parse(smem_space_str)
+ input_type_str = "f16" if input_type == np.float16 else "f32"
+ output_type_str = "f16" if output_type == np.float16 else "f32"
+ mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+ str(num_stages) + ">")
+ a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+ str(output_type_str) + ">>")
+ a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+ str(input_type_str) + ", " + smem_space_str + ">>")
+ b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+ str(input_type_str) + ", " + smem_space_str + ">>")
+
+ with ir.InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+ lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+ rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+ smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+ smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+ smem_size = max(smem_size_input, smem_size_output)
+
+ # Step 1. Allocate device memory and memcpy
+ t1 = gpu.wait(token_ty, [])
+ a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+ b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+ c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+ t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+ t7 = gpu.wait(token_ty, [t6])
+
+ # Step 2. Create TMA Descriptors
+ tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_descs = []
+ for x_device, tensor_map_ty, tile_shape in tma_specs:
+ x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+ tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+ a_tma_desc, b_tma_desc = tma_descs
+
+ # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+ cta_m = M // BLOCK_M
+ cta_n = N // BLOCK_N
+ assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+ grid = (cta_m, cta_n, 1)
+ block = (WARP_GROUP_SIZE * 2, 1, 1)
+ launch_op = gpu.LaunchOp(token_ty, [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=i32))
+ launch_op.body.blocks.append(*([index] * 12))
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
+ # GPU Step 0. This is needed for vectorized ld/st
+ memref.assume_alignment(c_device, 16)
+ dynamic_smem = gpu.dynamic_shared_memory(
+ ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+ ticks = c(10000000)
+
+ # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups and etc.
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+ warp_id = arith.divui(tidx, c(32))
+ warpgroup_id = arith.divui(warp_id, c(4))
+ is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+ c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+ is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+ c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+ producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
+ consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = arith.muli(bidx, c(BLOCK_M))
+ dimY = arith.muli(bidy, c(BLOCK_N))
+
+ # GPU Step 2. Initialize mbarrier groups
+ mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+ mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+ for i in range(num_stages):
+ nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+ nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+ gpu.barrier()
+
+ # GPU Step 3. Prefetch TMA descriptors
+ nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+
+ # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+ with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+
+ # Step 5.1. Reduce register size
+ nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+
+ # Step 5.2. TMA Main Loop
+ for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+ with ir.InsertionPoint(for_op.body):
+ phaseParity = for_op.inner_iter_args[0]
+ iv = for_op.induction_variable
+ stage = arith.remui(iv, c(num_stages))
+
+ # Step 5.2.1. Wait mbarDONE
+ debug_print("[prod] {} | mbarDONE[{}] try_wait phase={}",
+ iv,
+ stage,
+ ...
[truncated]
✅ With the latest revision this PR passed the Python code formatter.
@apaszke @manishucsd could you please review this?
Thanks for the great PR description, would it be worth adding this verbatim as a comment in the file? Edit: actually it is there, thanks :)
Force-pushed from ab3e7c4 to aa3e348
LGTM!
Thanks for adding this. I've left a few comments for consideration. The PR is a promising start in using MLIR's Python bindings to generate PTX. It's important to move forward with merging this, even if it currently supports a limited number of cases. Doing so will help maintain the robustness of the Python bindings, allowing us to continue refining and expanding upon the ideas introduced here.
Reviewable status: 0 of 5 files reviewed, 24 unresolved discussions (waiting on @grypp, @joker-eph, @manishucsd, and @vinodgro)
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 44 at r3 (raw file):
# # This kernel launches a single warp group (128 threads). The primary thread # (pthread) requests load from TMA. Threads collectively wait for the data and
pthread is rather engrained in folks. How about just main?
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 128 at r3 (raw file):
) mlir_nvgpu_module.operation.verify()
Why is this needed?
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 132 at r3 (raw file):
# Save generated IR if saveIR: # print(mlir_nvgpu_module)
Rm commented out code
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 140 at r3 (raw file):
# Get compiler options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
Should this be an option to the function?
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 147 at r3 (raw file):
) compiler = nvgpucompiler.NvgpuCompiler( options, opt_level=3, shared_libs=[support_lib]
Does this opt_level have to match the one above?
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 256 at r3 (raw file):
# Takes longer time to run
Are both run unconditionally? (it feels like one should be part of the integration tests while the other could be a unit test)
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 259 at r3 (raw file):
def test_long(): for stages in range(1, 7): for M in [128, 512, 1024, 4096, 8192]:
Should gtest parameterized be used here?
mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
line 1 at r3 (raw file):
# Files in this directory are tools, not tests.
Mmm, feels like we need a different spot for these.
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 25 at r3 (raw file):
CONSUMER_PRIMARY_THREAD = 0 MLIR_DYNAMIC = -9223372036854775808
Document? (also could you use the constant from the C/C++ side)
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 30 at r3 (raw file):
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
Is the thread the gpu specific part here?
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 61 at r3 (raw file):
def get_type_str(ty):
There has to be something already for this. What do you get if you str on the type?
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 76 at r3 (raw file):
def get_type_size(ty): if ir.F16Type.isinstance(ty):
The FloatType recently got added to query the width, so you could collapse the fp ones
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 85 at r3 (raw file):
return ir.IntegerType(ty).width // 8 if ir.IndexType.isinstance(ty): return 8
This is an assumption/not generally true. Why is index here? (I would have expected these to be lowered out before here to specific size)
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 177 at r3 (raw file):
smem_space_str = "#gpu.address_space<workgroup>" smem_space = ir.Attribute.parse(smem_space_str) mbar_ty = ir.Type.parse(
So none of the nvgpu types have python bindings yet, and just parse the string variant?
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 184 at r3 (raw file):
+ ">" ) a_tma_desc_ty = ir.Type.parse(
Could we wrap these behind helper functions so that we don't have the parses here?
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 376 at r3 (raw file):
) p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1)) phaseParity = arith.select(
Is this formatted using the black formatter? (I don't think we are testing here yet in presubmit)
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 694 at r3 (raw file):
smem_space_str = "#gpu.address_space<workgroup>" smem_space = ir.Attribute.parse(smem_space_str) mbar_ty = ir.Type.parse(
Same for these (well most parse instances)
mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
line 26 at r3 (raw file):
self.pipeline = pipeline self.shared_libs = shared_libs self.opt_level = opt_level
Should this be used to add to pipeline?
Previously, jpienaar (Jacques Pienaar) wrote…
One could either give the gpu thread id as an integer in …
Add test_short that tests multiple cases.
Previously, jpienaar (Jacques Pienaar) wrote…
I will delete this line.
Previously, jpienaar (Jacques Pienaar) wrote…
We can add that. For the time being we are parsing strings.
Previously, jpienaar (Jacques Pienaar) wrote…
Yes we can. I plan to improve this code in the next PR.
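A minimal sketch of what such helpers might look like (the helper names are hypothetical, not the PR's; the type strings are the ones already used in `matmulBuilder.py`, and an active `ir.Context` with the upstream dialects registered is assumed):

```python
from mlir import ir


def mbarrier_group_ty(num_barriers, smem_space="#gpu.address_space<workgroup>"):
    # Builds !nvgpu.mbarrier.group<memorySpace = ..., num_barriers = ...>.
    return ir.Type.parse(
        f"!nvgpu.mbarrier.group<memorySpace = {smem_space}, "
        f"num_barriers = {num_barriers}>"
    )


def tma_descriptor_ty(rows, cols, elem_ty_str, smem_space="#gpu.address_space<workgroup>"):
    # Builds the !nvgpu.tensormap.descriptor type for a rows x cols tile.
    return ir.Type.parse(
        f"!nvgpu.tensormap.descriptor<tensor = memref<{rows}x{cols}x{elem_ty_str}, "
        f"{smem_space}>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
    )
```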
Previously, jpienaar (Jacques Pienaar) wrote…
Yes it is from the black formatter. Apparently, the presubmit checks that.
(I tried out a new tool that I saw got added; I do not like how the tool inserted comments here and it seems not friendly for folks not using it, so I probably won't use it again.)
Previously, jpienaar (Jacques Pienaar) wrote…
This is for the host, not for the device.
Previously, jpienaar (Jacques Pienaar) wrote…
Where do you think we can put them? I think sparsifier is similar (https://github.com/llvm/llvm-project/blob/main/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparsifier.py)
Reviewable status: 0 of 5 files reviewed, 24 unresolved discussions (waiting on @joker-eph, @jpienaar, @manishucsd, and @vinodgro)
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 44 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
pthread is rather engrained in folks. How about just main?
I took the terminology "primary thread" from OpenMP, but the pthread shortcut isn't the best.
Let me think about that.
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 128 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Why is this needed?
Doesn't this verify the module? I am not sure whether the module is verified before this point.
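For context, a minimal illustration of the call (assuming the current binding behaviour, where `verify()` returns True on success and raises `ir.MLIRError` with diagnostics otherwise):

```python
from mlir import ir

with ir.Context(), ir.Location.unknown():
    m = ir.Module.parse("module {}")
    assert m.operation.verify()  # would raise ir.MLIRError if the IR were invalid
```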
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 132 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Rm commented out code
Done.
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 140 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Should this be an option to the function?
Good point, I can do that, but this code is entirely for Hopper GPUs, so I don't expect we will need any other compilation flags.
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 147 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Does this opt_level have to match the one above?
This one is for the host code while the other one is for the device code. They don't necessarily need to match.
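To make the split concrete, here is the pair of knobs as they appear in matmul.py (the host-side value 2 below is arbitrary, just to show the two can differ):

```python
import os

from tools import nvgpucompiler

support_lib = os.getenv("SUPPORT_LIB")
# `opt-level` inside the pass options drives device-side (PTX/cubin) compilation...
options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
# ...while `opt_level` here drives the host-side JIT; the two need not match.
compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=2, shared_libs=[support_lib])
```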
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 256 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Are both run unconditionally? (it feels like one should be part of the integration tests while the other could be a unit test)
I agree with you, but note that the LLVM test machines don't have Hopper GPUs :) so they won't run any of these tests.
I put this code as an example.
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
line 259 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Should gtest parameterized be used here?
How can I use that?
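gtest itself does not apply to these Python tests, but the same effect can be had with a plain parameter sweep (or pytest parametrization, if pytest were used). A hypothetical sketch using the values from `test_long`:

```python
import itertools

import numpy as np

# `matmul` is the entry point defined earlier in matmul.py.
def test_sweep():
    for stages, M, use_ws in itertools.product(
        range(1, 7),                   # max_num_stages, as in test_long
        [128, 512, 1024, 4096, 8192],  # M dimension, as in test_long
        [False, True],                 # Multistage vs. Warp Specialization
    ):
        # Square problems for brevity (an assumption, not the PR's sweep).
        matmul(
            np.float16,
            np.float32,
            M=M,
            N=M,
            K=M,
            max_num_stages=stages,
            use_warp_specilization=use_ws,  # parameter name as spelled in the PR
        )
```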
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 80 at r1 (raw file):
Previously, manishucsd (Manish Gupta) wrote…
ok. I see that: even though we have these come in as arguments to the function `generate_matmul_ws`, we are generating it for a fixed datatype.
Done.
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 83 at r1 (raw file):
Previously, manishucsd (Manish Gupta) wrote…
Sounds good!
For this PR, should we add asserts on (BLOCK_M, BLOCK_N, BLOCK_K) == (128,128,64)?
Done.
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 25 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Document? (also could you use the constant from the C/C++ side)
Done.
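One way to avoid hard-coding the sentinel, assuming `ShapedType.get_dynamic_size()` is available in the bindings used here:

```python
from mlir import ir

with ir.Context():
    # Matches the hard-coded MLIR_DYNAMIC = -9223372036854775808 in the PR,
    # but stays in sync with the C++ kDynamic constant automatically.
    MLIR_DYNAMIC = ir.ShapedType.get_dynamic_size()
    assert MLIR_DYNAMIC == -9223372036854775808
```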
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 61 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
There has to be something already for this. What do you get if you str on the type?
Done.
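For reference, `str()` on an MLIR type already yields its textual form, which is what a hand-written `get_type_str` helper would otherwise reproduce:

```python
from mlir import ir

with ir.Context():
    assert str(ir.F16Type.get()) == "f16"
    assert str(ir.F32Type.get()) == "f32"
    assert str(ir.IntegerType.get_signless(32)) == "i32"
```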
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 76 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
The FloatType recently got added to query the width, so you could collapse the fp ones
Done.
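A hedged sketch of the collapsed helper, assuming the bindings expose a generic `ir.FloatType` with `isinstance` and a `width` property as mentioned above:

```python
from mlir import ir


def get_type_size(ty):
    # One branch for all float types (f16/f32/f64/...) instead of one per type.
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)
```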
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
line 694 at r3 (raw file):
Previously, jpienaar (Jacques Pienaar) wrote…
Same for these (well most parse instances)
Done.
Force-pushed from 6da82f4 to 89370ef
@grypp: this broke the bot FYI (you should have had an email)
@joker-eph yes, I've just seen them. But I don't really understand why.
I didn't look at the failure before, this is mysterious to me: it seems to have resolved itself. I don't quite understand, seeing how it affected many bots at the same time on the exact same range.
This test aims to validate the correctness of the supported GEMM kernels in NVGPU dialects, with current support for Multistage and Warp Specialization kernels.
The test constructs and metaprograms IR using Python bindings, allowing generic IR building. This flexibility enables changes to the shape, tile size, or data type of the GEMM for testing purposes. The entry function is `matmul`, where one can specify GEMM shape, tile size, data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum number of stages. Verification is done via numpy's matmul operation.
Example:
Parallelism Across CTAs
GEMM includes three loops defining the shape of the GEMM, specified in the `matmul` function. The program builds IR using the following loop structure, tiling the loops with the given tile size and parallelizing the two outermost loops into the first and second dimensions of CTAs.
Multistage Kernel
This kernel launches a single warp group (128 threads). The primary thread (pthread) requests load from TMA. Threads collectively wait for the data and perform mma operations. After completing the shape, threads together store first fragmented registers to shared memory, then from shared memory to global memory; this part is called the epilogue.
Execution Timeline of Multistage Kernel with 3 stages:
Warp Specialization Kernel
This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one as `producer warp group` and another as `consumer warp group`. The `producer warp group` is responsible for requesting TMA load, while the `consumer warp group` performs the mma operation. The epilogue section is handled by the `consumer warp group` as its threads own the fragmented registers.
Execution Timeline of Warp Specialization Kernel with 2 stages:
Limitations
- `128x128x64` Tile Size. Possible to extend this functionality in Python bindings if it is desired. It is not here as we want to test nvgpu ops here, not the TMA itself.
- `F32 += F16 * F16` operations. Further expansion of nvgpu ops support is needed for other data types.
Benchmarks with Tile Size 128x128x64
To clarify, this is designed as a test rather than a benchmark, yet it provides insightful preliminary performance metrics. The graph below illustrates the performance on H100 SXM, with cuBLAS 12.1 serving as a reference for comparison. It's important to note that both are still below the peak performance of ~1000 TFLOPS for `f16`.
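For background on how such throughput numbers are derived (not part of the PR): a GEMM of shape MxNxK performs 2*M*N*K floating-point operations, so TFLOPS follow directly from the measured runtime.

```python
def gemm_tflops(M, N, K, seconds):
    # One multiply and one add per inner-product term -> 2*M*N*K FLOPs.
    return (2 * M * N * K) / seconds / 1e12


# Example: a 4096x4096x4096 GEMM finishing in 350 microseconds.
print(gemm_tflops(4096, 4096, 4096, 350e-6))  # ~392.7 TFLOPS
```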