@@ -359,6 +359,127 @@ def matmul_tma_persistent(a, b):
    return c


+@triton.jit(launch_metadata=_matmul_launch_metadata)
+def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
+                                        a_ptr, b_ptr, c_ptr,  #
+                                        ready_flag,  #
+                                        M, N, K,  #
+                                        BLOCK_SIZE_M: tl.constexpr,  #
+                                        BLOCK_SIZE_N: tl.constexpr,  #
+                                        BLOCK_SIZE_K: tl.constexpr,  #
+                                        GROUP_SIZE_M: tl.constexpr,  #
+                                        NUM_SMS: tl.constexpr):  #
+    # Matmul using TMA and device-side descriptor creation
+    dtype = c_ptr.dtype.element_ty
+    start_pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+    num_tiles = num_pid_m * num_pid_n
+
+    if start_pid == 0:
+        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
+                                                             element_ty=a_ptr.dtype.element_ty)
+        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                             load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
+                                                             element_ty=b_ptr.dtype.element_ty)
+        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
+                                                             element_ty=c_ptr.dtype.element_ty)
+        tl.atomic_xchg(ready_flag, 1, sem="release")
+    else:
+        flag = tl.full([], 0, tl.int32)
+        while flag != 1:
+            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    tile_id = start_pid - NUM_SMS
+    ki = -1
+
+    pid_m = 0
+    pid_n = 0
+    offs_am = 0
+    offs_bn = 0
+
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for _ in range(0, k_tiles * tiles_per_SM):
+        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
+        if ki == 0:
+            tile_id += NUM_SMS
+            group_id = tile_id // num_pid_in_group
+            first_pid_m = group_id * GROUP_SIZE_M
+            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+            pid_m = first_pid_m + (tile_id % group_size_m)
+            pid_n = (tile_id % num_pid_in_group) // group_size_m
+
+            offs_am = pid_m * BLOCK_SIZE_M
+            offs_bn = pid_n * BLOCK_SIZE_N
+
+        offs_k = ki * BLOCK_SIZE_K
+
+        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
+        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
+        accumulator = tl.dot(a, b.T, accumulator)
+
+        if ki == k_tiles - 1:
+            c = accumulator.to(dtype)
+
+            tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+
+def matmul_device_tma_persistent(a, b):
+    # Autotuner does not work with TMA. Use manual config.
+    configs = {
+        torch.float8_e4m3fn: {
+            "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
+            "num_warps": 8
+        }, torch.float16: {
+            "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3,
+            "num_warps": 8
+        }
+    }
+
+    # Check constraints.
+    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
+    assert a.dtype == b.dtype, "Incompatible dtypes"
+
+    M, K = a.shape
+    N, K = b.shape
+    dtype = a.dtype
+
+    c = torch.zeros((M, N), device=a.device, dtype=dtype)
+    a_desc, b_desc, c_desc = [torch.empty(128, dtype=torch.uint8, device="cuda") for _ in range(3)]
+    ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
+    matmul_kernel_device_tma_persistent[grid](
+        a_desc, b_desc, c_desc,  #
+        a, b, c,  #
+        ready_flag,  #
+        M, N, K,  #
+        BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],  #
+        BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],  #
+        BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"],  #
+        GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"],  #
+        NUM_SMS=NUM_SMS,  #
+        num_stages=configs[dtype]["num_stages"],  #
+        num_warps=configs[dtype]["num_warps"],  #
+    )
+    return c
+
+
def cublas_matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
@@ -414,6 +535,9 @@ def bench(K, dtype, reps=10):
        for _ in range(reps):
            matmul_tma_persistent(a, b)
            time.sleep(0.01)
+        for _ in range(reps):
+            matmul_device_tma_persistent(a, b)
+            time.sleep(0.01)

    proton.deactivate(0)

@@ -428,6 +552,7 @@ def validate(M, N, K, dtype):
    naive_result = matmul(a, b.T)
    persistent_result = matmul_persistent(a, b.T)
    tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
+    device_tma_persistent_result = matmul_device_tma_persistent(a, b) if supports_tma() else None

    if torch_result is not None:
        naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -440,14 +565,20 @@ def validate(M, N, K, dtype):
    if tma_persistent_result is not None:
        naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16),
                                                        tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
+    if device_tma_persistent_result is not None:
+        naive_vs_device_tma_persistent = "✅" if torch.allclose(cublas_result.to(
+            torch.float16), device_tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
    print(f"M={M}, N={N}, K={K} verification naive vs: ", end="")
    if torch_result is not None:
        print(f"torch: {naive_vs_torch} ", end="")
    if cublas_result is not None:
        print(f"cublas: {naive_vs_cublas} ", end="")
    print(f"persistent: {naive_vs_persistent} ", end="")
    if tma_persistent_result is not None:
-        print(f"TMA persistent: {naive_vs_tma_persistent} ")
+        print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
+    if device_tma_persistent_result is not None:
+        print(f"Device TMA persistent: {naive_vs_device_tma_persistent} ", end="")
+    print()


if __name__ == "__main__":
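
For orientation beyond the diff itself, here is a minimal usage sketch of the new wrapper. It is an illustration, not part of this change: it assumes the tutorial file's definitions (matmul_device_tma_persistent, supports_tma) and imports are in scope, a TMA-capable GPU (Hopper or newer), and arbitrary example shapes.

    import torch

    if supports_tma():
        M, N, K = 4096, 4096, 512
        a = torch.randn((M, K), device="cuda", dtype=torch.float16)
        # Like matmul_tma_persistent, the wrapper expects b already transposed to (N, K).
        b = torch.randn((N, K), device="cuda", dtype=torch.float16)
        c = matmul_device_tma_persistent(a, b)
        ref = torch.matmul(a, b.T)
        print(torch.allclose(c, ref, atol=1.0))

Compared with the host-side TMA path, the host here only allocates three 128-byte uint8 buffers and a ready_flag; program 0 of the kernel fills the TMA descriptors on device and publishes them with a release exchange, while the remaining programs spin on the flag with acquire semantics and issue the tensormap fence before their first descriptor load.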