1
1
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2
2
# RUN: %PYTHON %s | FileCheck %s
3
- # CHECK: PASS
3
+
4
4
5
5
# ===--- GEMM Hopper Tensor Core Integration Test ---===
6
6
#
@@ -168,6 +168,8 @@ def matmul(
168
168
no_verify = False ,
169
169
):
170
170
# Print the configuration
171
+ required_stages = (M * K + K * N ) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N )
172
+ num_stages = min (required_stages , max_num_stages )
171
173
ity = "f16" if input_type == np .float16 else "f32"
172
174
oty = "f16" if output_type == np .float16 else "f32"
173
175
gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
@@ -193,7 +195,7 @@ def matmul(
193
195
+ "x"
194
196
+ str (BLOCK_K )
195
197
+ ", stages "
196
- + str (max_num_stages )
198
+ + str (num_stages )
197
199
+ " --==="
198
200
)
199
201
@@ -209,7 +211,7 @@ def matmul(
209
211
BLOCK_K ,
210
212
use_warp_specialization ,
211
213
saveIR ,
212
- max_num_stages ,
214
+ num_stages ,
213
215
)
214
216
215
217
# Allocate matrices and invoke the matmul
@@ -219,10 +221,17 @@ def matmul(
219
221
mem_a = ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (a )))
220
222
mem_b = ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (b )))
221
223
mem_c = ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (c )))
222
- kernelName = (
223
- "mlir_matmul_warpspecialized"
224
- if use_warp_specialization
225
- else "mlir_matmul_multistage"
224
+ kernelName = matmulBuilder .make_kernel_name (
225
+ input_type ,
226
+ output_type ,
227
+ M ,
228
+ N ,
229
+ K ,
230
+ BLOCK_M ,
231
+ BLOCK_N ,
232
+ BLOCK_K ,
233
+ num_stages ,
234
+ use_warp_specialization ,
226
235
)
227
236
228
237
# Launch the MLIR generated kernel
@@ -243,24 +252,118 @@ def matmul(
243
252
244
253
print ("PASS " )
245
254
255
# Exhaustive configuration sweep — takes a long time to run.
def test_long():
    """Run matmul over a large grid of problem sizes and stage counts.

    For every (stages, M, N, K) combination, first run the multistage
    kernel with result verification disabled and its IR saved, then run
    the warp-specialized kernel with verification enabled.
    """
    for stage_count in range(1, 7):
        for size_m in [128, 512, 1024, 4096, 8192]:
            for size_n in [128, 512, 1024, 4096, 8192]:
                for size_k in [64, 128, 512, 1024, 4096, 8192]:
                    # Multistage kernel: skip verification, dump the IR.
                    matmul(
                        np.float16,
                        np.float32,
                        size_m,
                        size_n,
                        size_k,
                        max_num_stages=stage_count,
                        use_warp_specialization=False,
                        no_verify=True,
                        saveIR=True,
                    )
                    # Warp-specialized kernel, with result verification.
                    matmul(
                        np.float16,
                        np.float32,
                        size_m,
                        size_n,
                        size_k,
                        max_num_stages=stage_count,
                        use_warp_specialization=True,
                    )
246
279
247
- # GEMM Multistage f32 += f16 * f16
248
- matmul (
249
- np .float16 ,
250
- np .float32 ,
251
- 128 ,
252
- 128 ,
253
- 4096 ,
254
- max_num_stages = 3 ,
255
- use_warp_specialization = False ,
256
- )
257
- # GEMM Warp Specilized f32 += f16 * f16
258
- matmul (
259
- np .float16 ,
260
- np .float32 ,
261
- 256 ,
262
- 1024 ,
263
- 512 ,
264
- max_num_stages = 3 ,
265
- use_warp_specialization = True ,
266
- )
280
def test_short():
    """Run matmul over a small grid of problem sizes and stage counts.

    Mirrors ``test_long`` on a reduced configuration space: each
    combination runs the multistage kernel (verification skipped, IR
    saved) followed by the verified warp-specialized kernel.
    """
    for stage_count in [1, 3]:
        for size_m in [128, 512]:
            for size_n in [128, 512]:
                for size_k in [64, 512]:
                    # Multistage kernel: skip verification, dump the IR.
                    matmul(
                        np.float16,
                        np.float32,
                        size_m,
                        size_n,
                        size_k,
                        max_num_stages=stage_count,
                        use_warp_specialization=False,
                        no_verify=True,
                        saveIR=True,
                    )
                    # Warp-specialized kernel, with result verification.
                    matmul(
                        np.float16,
                        np.float32,
                        size_m,
                        size_n,
                        size_k,
                        max_num_stages=stage_count,
                        use_warp_specialization=True,
                    )
303
+
304
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
305
+ # CHECK: PASS
306
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
307
+ # CHECK: PASS
308
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
309
+ # CHECK: PASS
310
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
311
+ # CHECK: PASS
312
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
313
+ # CHECK: PASS
314
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
315
+ # CHECK: PASS
316
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
317
+ # CHECK: PASS
318
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
319
+ # CHECK: PASS
320
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
321
+ # CHECK: PASS
322
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
323
+ # CHECK: PASS
324
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
325
+ # CHECK: PASS
326
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
327
+ # CHECK: PASS
328
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
329
+ # CHECK: PASS
330
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
331
+ # CHECK: PASS
332
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
333
+ # CHECK: PASS
334
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
335
+ # CHECK: PASS
336
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
337
+ # CHECK: PASS
338
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
339
+ # CHECK: PASS
340
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
341
+ # CHECK: PASS
342
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
343
+ # CHECK: PASS
344
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
345
+ # CHECK: PASS
346
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
347
+ # CHECK: PASS
348
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
349
+ # CHECK: PASS
350
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
351
+ # CHECK: PASS
352
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
353
+ # CHECK: PASS
354
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
355
+ # CHECK: PASS
356
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
357
+ # CHECK: PASS
358
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
359
+ # CHECK: PASS
360
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
361
+ # CHECK: PASS
362
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
363
+ # CHECK: PASS
364
+ # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
365
+ # CHECK: PASS
366
+ # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
367
+ # CHECK: PASS
368
+
369
+ test_short ()
0 commit comments