Commit d95e6d0

[mlir] GEMM Hopper Tensor Core Integration Test (#81478)

1 parent 2e93ee6

File tree: 5 files changed, +1547 -0 lines changed
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
    config.unsupported = True
Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s


# ===--- GEMM Hopper Tensor Core Integration Test ---===
#
# This test validates the correctness of the GEMM kernels supported by the
# NVGPU dialect, currently the Multistage and the Warp Specialization kernels.
# The test constructs and metaprograms the IR through the Python bindings,
# which allows generic IR building: the shape, tile size, and data type of the
# GEMM can all be changed for testing purposes.
# The entry function is `matmul`, where one can specify the GEMM shape, tile
# size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
# maximum number of stages.
# Results are verified against numpy's matmul.
#
# Example:
#   matmul(input_type=np.float16,                # input type
#          output_type=np.float32,               # output type
#          M=4096, N=4096, K=4096,               # shape
#          BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # tile size
#          use_warp_specialization=True,         # enable Warp Specialization
#          max_num_stages=3)                     # stages in shared memory
#
# ===--- Parallelism Across CTAs ---===
#
# GEMM consists of three loops that define its shape, specified in the
# `matmul` function. The program builds IR with the following loop structure,
# tiling the loops with the given tile sizes and parallelizing the two
# outermost loops across the first and second dimensions of the CTA grid:
#
#   for(bi = 0; bi < M; bi += BLOCK_M)            # parallelize across blockIdx.x
#     for(bj = 0; bj < N; bj += BLOCK_N)          # parallelize across blockIdx.y
#       for(bk = 0; bk < K; bk += BLOCK_K)
#         for(i = bi; i < (bi + BLOCK_M); ++i)
#           for(j = bj; j < (bj + BLOCK_N); ++j)
#             for(k = bk; k < (bk + BLOCK_K); ++k)
#
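# As a rough sketch (illustration only, not code emitted by this test), the
# launch grid implied by the two parallelized loops is:
#
#   gridDim.x = ceil(M / BLOCK_M)                 # one CTA per M-tile
#   gridDim.y = ceil(N / BLOCK_N)                 # one CTA per N-tile
#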
# ===--- Multistage Kernel ---===
#
# This kernel launches a single warp group (128 threads). The primary thread
# (pthread) requests loads from TMA. Threads collectively wait for the data
# and perform mma operations. After the main loop completes, the threads
# cooperatively store the fragmented accumulator registers first to shared
# memory, then from shared memory to global memory; this part is called the
# epilogue.
#
# Execution Timeline of Multistage Kernel with 3 stages:
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |       |Prologue ---->  |MainLoop ---->                                                       |Epilogue               |
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
#
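# A sketch of the pipelining pattern above (illustrative pseudocode only;
# `tma_load`, `wait`, and `mma` are hypothetical helpers, not this test's API;
# the actual IR comes from matmulBuilder.generate_matmul_multistage):
#
#   for s in range(num_stages):                   # prologue: prefetch all stages
#       tma_load(stage=s, k_tile=s)
#   for k in range(K // BLOCK_K):                 # main loop
#       wait(stage=k % num_stages)                # wait for the oldest load
#       mma(stage=k % num_stages)                 # accumulate into registers
#       if k + num_stages < K // BLOCK_K:         # refill the freed stage
#           tma_load(stage=k % num_stages, k_tile=k + num_stages)
#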
# ===--- Warp Specialization Kernel ---===
#
# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
# as the `producer warp group` and the other as the `consumer warp group`. The
# `producer warp group` is responsible for requesting TMA loads, while the
# `consumer warp group` performs the mma operations. The epilogue section is
# handled by the `consumer warp group`, as its threads own the fragmented
# accumulator registers.
#
# Execution Timeline of Warp Specialization Kernel with 2 stages:
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+

import ctypes
import errno
import os

import numpy as np

from mlir import runtime as rt
from tools import matmulBuilder
from tools import nvgpucompiler


def generate_matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    options="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
):
    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
        if use_warp_specialization:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )
        else:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )

        mlir_nvgpu_module.operation.verify()

        # Save the generated IR if requested
        if saveIR:
            with open("gemm.mlir", "w") as f:
                f.write(str(mlir_nvgpu_module))

        # Locate the CUDA runtime support library passed in by lit
        support_lib = os.getenv("SUPPORT_LIB")
        if not os.path.exists(support_lib):
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
            )
        compiler = nvgpucompiler.NvgpuCompiler(
            options, opt_level=3, shared_libs=[support_lib]
        )

        # Compile and JIT the module
        engine = compiler.compile_and_jit(mlir_nvgpu_module)
        return engine


def matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=128,
    N=128,
    K=128,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    print_results=False,
    no_verify=False,
):
    # Print the configuration. Cap the stage count by how many stage buffers
    # the problem can actually fill (total A and B elements divided by the
    # size of one stage's A and B tiles).
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    num_stages = min(required_stages, max_num_stages)
    ity = "f16" if input_type == np.float16 else "f32"
    oty = "f16" if output_type == np.float16 else "f32"
    gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
    print(
        f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
        f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
        f"stages {num_stages} --==="
    )

    # Build the IR and compile it
    engine = generate_matmul(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        use_warp_specialization,
        saveIR,
        num_stages,
    )

    # Allocate the matrices and invoke the matmul. The JIT-compiled kernel
    # receives each argument as a pointer to a ranked memref descriptor.
    c = np.zeros((M, N), output_type)
    a = np.random.randn(M, K).astype(input_type)
    b = np.random.randn(K, N).astype(input_type)
    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
    kernelName = matmulBuilder.make_kernel_name(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        num_stages,
        use_warp_specialization,
    )

    # Launch the MLIR-generated kernel
    engine.invoke(kernelName, mem_a, mem_b, mem_c)

    float_formatter = "{:.2f}".format
    np.set_printoptions(formatter={"float_kind": float_formatter})

    if print_results:
        print(c)

    # Verify the results against numpy's matmul
    if not no_verify:
        ref = a.astype(input_type) @ b.astype(input_type)
        if print_results:
            print(ref)
        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)

    print("PASS ")


# This sweep takes a long time to run.
def test_long():
    for stages in range(1, 7):
        for M in [128, 512, 1024, 4096, 8192]:
            for N in [128, 512, 1024, 4096, 8192]:
                for K in [64, 128, 512, 1024, 4096, 8192]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                        no_verify=True,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


def test_short():
    for stages in [1, 3]:
        for M in [128, 512]:
            for N in [128]:
                for K in [64, 256]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS

test_short()
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# Files in this directory are tools, not tests.
config.unsupported = True