Commit ad767a6

format
1 parent 09fb9c8 commit ad767a6

2 files changed (+44, -75 lines)


mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py

Lines changed: 40 additions & 68 deletions
@@ -252,118 +252,90 @@ def matmul(
 
     print("PASS ")
 
+
 # Takes longer time to run
 def test_long():
-    for stages in range(1,7):
+    for stages in range(1, 7):
         for M in [128, 512, 1024, 4096, 8192]:
             for N in [128, 512, 1024, 4096, 8192]:
                 for K in [64, 128, 512, 1024, 4096, 8192]:
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
                         use_warp_specialization=False,
                         no_verify=True,
-                        saveIR=True
                     )
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
-                        use_warp_specialization=True,
+                        use_warp_specialization=True,
                     )
 
+
 def test_short():
     for stages in [1, 3]:
         for M in [128, 512]:
-            for N in [128, 512]:
-                for K in [64, 512]:
+            for N in [128]:
+                for K in [64, 256]:
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
                         use_warp_specialization=False,
-                        no_verify=True,
-                        saveIR=True
                     )
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
-                        use_warp_specialization=True,
+                        use_warp_specialization=True,
                     )
 
+
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
 # CHECK: PASS
 
-test_short()
+test_short()
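
For reference, the trimmed test_short sweep reduces to pairs of calls like the one below: a minimal sketch assuming this test file's matmul() helper with positional M, N, K arguments (only the first configuration of the new matrix is shown), run once with the multistage pipeline and once with warp specialization. Each run emits the "===-- Running GEMM ... --===" banner and PASS line that the CHECK directives above match.

    import numpy as np

    # First point of the new test_short matrix: stages=1, M=128, N=128, K=64.
    matmul(np.float16, np.float32, 128, 128, 64,
           max_num_stages=1, use_warp_specialization=False)
    matmul(np.float16, np.float32, 128, 128, 64,
           max_num_stages=1, use_warp_specialization=True)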

mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py

Lines changed: 4 additions & 7 deletions
@@ -117,11 +117,7 @@ def make_kernel_name(
     num_stages=3,
     use_warp_specialization=False,
 ):
-    kernelName = (
-        "warpspecialized"
-        if use_warp_specialization
-        else "multistage"
-    )
+    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
     return (
         kernelName
         + "_"
@@ -140,6 +136,7 @@ def make_kernel_name(
         + str(num_stages)
     )
 
+
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,
@@ -644,7 +641,7 @@ def generate_matmul_ws(
         t8 = gpu.wait(token_ty, [launch_op])
         t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
         gpu.dealloc(token_ty, [t8], a_device)
-        gpu.dealloc(token_ty, [t8], b_device)
+        gpu.dealloc(token_ty, [t8], b_device)
         gpu.wait(token_ty, [t9])
         gpu.dealloc(token_ty, [t8], c_device)
         func.ReturnOp([])
@@ -1162,7 +1159,7 @@ def generate_matmul_multistage(
         t8 = gpu.wait(token_ty, [launch_op])
         t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
         gpu.dealloc(token_ty, [t8], a_device)
-        gpu.dealloc(token_ty, [t8], b_device)
+        gpu.dealloc(token_ty, [t8], b_device)
         gpu.wait(token_ty, [t9])
         gpu.dealloc(token_ty, [t8], c_device)
         func.ReturnOp([])
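
Both generate_matmul_ws and generate_matmul_multistage end the host-side function with the same asynchronous teardown that the last two hunks touch. A commented sketch of that pattern, using only the names visible in the diff context (the surrounding MLIR builder setup is assumed):

    t8 = gpu.wait(token_ty, [launch_op])                # async token: kernel launch has completed
    t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)   # copy the result C back to the host
    gpu.dealloc(token_ty, [t8], a_device)               # A and B depend only on the launch token
    gpu.dealloc(token_ty, [t8], b_device)
    gpu.wait(token_ty, [t9])                            # block until the copy has finished
    gpu.dealloc(token_ty, [t8], c_device)               # then free the output buffer
    func.ReturnOp([])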
