@@ -252,118 +252,90 @@ def matmul(
252
252
253
253
print ("PASS " )
254
254
255
+
255
256
# Takes longer time to run
256
257
def test_long():
    """Run the exhaustive matmul sweep (slow).

    For every combination of ``max_num_stages`` in 1..6, M and N in
    {128, 512, 1024, 4096, 8192}, and K in {64, 128, 512, 1024, 4096, 8192}
    (900 combinations), ``matmul`` is invoked twice with f16 inputs and an
    f32 accumulator:

    1. ``use_warp_specialization=False`` with ``no_verify=True``
    2. ``use_warp_specialization=True`` (``no_verify`` left at its default)
    """
    stage_counts = range(1, 7)
    mn_sizes = [128, 512, 1024, 4096, 8192]
    k_sizes = [64, 128, 512, 1024, 4096, 8192]

    for stages in stage_counts:
        for M in mn_sizes:
            for N in mn_sizes:
                for K in k_sizes:
                    # NOTE(review): only this first call passes no_verify=True;
                    # the asymmetry with the second call is preserved as-is.
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                        no_verify=True,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )
279
281
282
+
280
283
def test_short():
    """Run the quick matmul sweep whose output is checked by FileCheck.

    Iterates ``max_num_stages`` over [1, 3], M over [128, 512], N over [128]
    and K over [64, 256]. For each combination ``matmul`` is called twice
    with f16 inputs and an f32 accumulator: first with
    ``use_warp_specialization=False``, then with
    ``use_warp_specialization=True``.

    The ``# CHECK:`` lines in this file list the expected output in this
    exact iteration order, so changing the loop nesting or the size lists
    requires updating those lines as well.
    """
    for stages in [1, 3]:
        for M in [128, 512]:
            for N in [128]:
                for K in [64, 256]:
                    # Same problem size, once per kernel variant.
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )
303
306
307
+
304
308
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
305
- # CHECK: PASS
309
+ # CHECK: PASS
306
310
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
307
- # CHECK: PASS
308
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
309
- # CHECK: PASS
310
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
311
- # CHECK: PASS
312
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
313
- # CHECK: PASS
314
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
315
- # CHECK: PASS
316
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
317
- # CHECK: PASS
318
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
319
- # CHECK: PASS
311
+ # CHECK: PASS
312
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
313
+ # CHECK: PASS
314
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
315
+ # CHECK: PASS
320
316
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
321
- # CHECK: PASS
317
+ # CHECK: PASS
322
318
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
323
- # CHECK: PASS
324
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
325
- # CHECK: PASS
326
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
327
- # CHECK: PASS
328
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
329
- # CHECK: PASS
330
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
331
- # CHECK: PASS
332
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
333
- # CHECK: PASS
334
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
335
- # CHECK: PASS
319
+ # CHECK: PASS
320
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
321
+ # CHECK: PASS
322
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
323
+ # CHECK: PASS
336
324
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
337
- # CHECK: PASS
325
+ # CHECK: PASS
338
326
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
339
- # CHECK: PASS
340
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
341
- # CHECK: PASS
342
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
343
- # CHECK: PASS
344
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
345
- # CHECK: PASS
346
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
347
- # CHECK: PASS
348
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
349
- # CHECK: PASS
350
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
351
- # CHECK: PASS
327
+ # CHECK: PASS
328
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
329
+ # CHECK: PASS
330
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
331
+ # CHECK: PASS
352
332
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
353
- # CHECK: PASS
333
+ # CHECK: PASS
354
334
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
355
- # CHECK: PASS
356
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
357
- # CHECK: PASS
358
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
359
- # CHECK: PASS
360
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
361
- # CHECK: PASS
362
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
363
- # CHECK: PASS
364
- # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
365
- # CHECK: PASS
366
- # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
335
+ # CHECK: PASS
336
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
337
+ # CHECK: PASS
338
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
367
339
# CHECK: PASS
368
340
369
# Only the short sweep runs when this file is executed; its output is what
# the CHECK lines above are matched against.
test_short()
0 commit comments