
Commit 0934dcb

Allow multiple stages, and fix the kernels.
Add test_short that tests multiple cases.
1 parent 777f208 commit 0934dcb

2 files changed: +308 −155

mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py

Lines changed: 130 additions & 27 deletions
@@ -1,6 +1,6 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
 # RUN:   %PYTHON %s | FileCheck %s
-# CHECK: PASS
+
 
 # ===--- GEMM Hopper Tensor Core Integration Test ---===
 #
@@ -168,6 +168,8 @@ def matmul(
     no_verify=False,
 ):
     # Print the configuration
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
@@ -193,7 +195,7 @@ def matmul(
         + "x"
         + str(BLOCK_K)
         + ", stages "
-        + str(max_num_stages)
+        + str(num_stages)
         + " --==="
     )
 
@@ -209,7 +211,7 @@ def matmul(
         BLOCK_K,
         use_warp_specialization,
         saveIR,
-        max_num_stages,
+        num_stages,
     )
 
     # Allocate matrices and invoke the matmul
@@ -219,10 +221,17 @@ def matmul(
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = (
-        "mlir_matmul_warpspecialized"
-        if use_warp_specialization
-        else "mlir_matmul_multistage"
+    kernelName = matmulBuilder.make_kernel_name(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        num_stages,
+        use_warp_specialization,
     )
 
     # Launch the MLIR generated kernel
@@ -243,24 +252,118 @@ def matmul(
 
     print("PASS ")
 
+# Takes a long time to run
+def test_long():
+    for stages in range(1, 7):
+        for M in [128, 512, 1024, 4096, 8192]:
+            for N in [128, 512, 1024, 4096, 8192]:
+                for K in [64, 128, 512, 1024, 4096, 8192]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                        no_verify=True,
+                        saveIR=True,
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,
+                    )
 
-# GEMM Multistage f32 += f16 * f16
-matmul(
-    np.float16,
-    np.float32,
-    128,
-    128,
-    4096,
-    max_num_stages=3,
-    use_warp_specialization=False,
-)
-# GEMM Warp Specilized f32 += f16 * f16
-matmul(
-    np.float16,
-    np.float32,
-    256,
-    1024,
-    512,
-    max_num_stages=3,
-    use_warp_specialization=True,
-)
+def test_short():
+    for stages in [1, 3]:
+        for M in [128, 512]:
+            for N in [128, 512]:
+                for K in [64, 512]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                        no_verify=True,
+                        saveIR=True,
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,
+                    )
+
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+
+test_short()
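
For reference, the stage counts printed in the CHECK lines follow from the clamping this commit adds at the top of matmul(): required_stages measures how much A/B tile data the problem supplies relative to one stage's worth, and the run uses min(required_stages, max_num_stages). A minimal standalone sketch of that arithmetic (the expected_stages helper is illustrative, not part of the test; the tile sizes are the 128x128x64 defaults the tests use):

    # Sketch of the stage-clamping logic added to matmul() in this commit.
    # Tile sizes are taken from the CHECK lines above (Tile 128x128x64);
    # the name `expected_stages` is hypothetical.
    def expected_stages(M, N, K, max_num_stages,
                        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64):
        # Total A (M*K) and B (K*N) elements divided by one stage's worth
        # of tile data, then clamped to the requested maximum.
        required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
        return min(required_stages, max_num_stages)

    # Reproduces the CHECK lines: 128x512x64 with max 3 stages runs with 2,
    # since (128*64 + 64*512) // (128*64 + 64*128) == 40960 // 16384 == 2.
    assert expected_stages(128, 512, 64, max_num_stages=3) == 2
    assert expected_stages(128, 128, 64, max_num_stages=3) == 1   # clamped to 1
    assert expected_stages(512, 512, 512, max_num_stages=3) == 3  # capped at max

This is why several of the max_num_stages=3 configurations above report "stages 1" or "stages 2" rather than 3.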
