Commit b80a7fe

peterbell10 authored and vlad-penkin committed
[Tutorial] Add device-side tensormap update to persistent matmul tutorial (#4648)

This adds a 4th variant to the persistent matmul tutorial that uses the device-side tensormap creation API. When running the tutorial I do see a small reduction in utilization, but I suppose this is to be expected. The result is still superior to not using TMA, though:

```
├─ 0.244 matmul_kernel [M=8192, N=8192, K=512]
├─ 0.285 matmul_kernel_device_tma_persistent [M=8192, N=8192, K=512]
├─ 0.259 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 0.288 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
```
1 parent 609a906 commit b80a7fe

File tree

1 file changed: +132 -1 lines changed

python/tutorials/09-persistent-matmul.py

Lines changed: 132 additions & 1 deletion
```
@@ -359,6 +359,127 @@ def matmul_tma_persistent(a, b):
     return c
 
 
+@triton.jit(launch_metadata=_matmul_launch_metadata)
+def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
+                                        a_ptr, b_ptr, c_ptr,  #
+                                        ready_flag,  #
+                                        M, N, K,  #
+                                        BLOCK_SIZE_M: tl.constexpr,  #
+                                        BLOCK_SIZE_N: tl.constexpr,  #
+                                        BLOCK_SIZE_K: tl.constexpr,  #
+                                        GROUP_SIZE_M: tl.constexpr,  #
+                                        NUM_SMS: tl.constexpr):  #
+    # Matmul using TMA and device-side descriptor creation
+    dtype = c_ptr.dtype.element_ty
+    start_pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+    num_tiles = num_pid_m * num_pid_n
+
+    if start_pid == 0:
+        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
+                                                             element_ty=a_ptr.dtype.element_ty)
+        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                             load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
+                                                             element_ty=b_ptr.dtype.element_ty)
+        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
+                                                             element_ty=c_ptr.dtype.element_ty)
+        tl.atomic_xchg(ready_flag, 1, sem="release")
+    else:
+        flag = tl.full([], 0, tl.int32)
+        while flag != 1:
+            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+
```
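Program 0 builds all three descriptors in global memory, then publishes them with a release-semantics store to `ready_flag`; every other program spins on an acquire-semantics read and finally issues `experimental_tensormap_fenceproxy_acquire`, since TMA reads tensormaps through a different memory proxy than the one they were written through. A minimal standalone sketch of just this init-once handshake (the kernel and buffer names here are hypothetical, not part of the tutorial):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def init_once_kernel(ready_flag, out_ptr, val):
    # Program 0 performs the one-time initialization; the rest wait for it.
    # Spinning like this assumes all programs are co-resident on the GPU,
    # which holds as long as the grid is no larger than the SM count.
    if tl.program_id(axis=0) == 0:
        tl.store(out_ptr, val)
        tl.atomic_xchg(ready_flag, 1, sem="release")  # publish the write
    else:
        flag = tl.full([], 0, tl.int32)
        while flag != 1:  # spin until program 0 signals readiness
            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
        # out_ptr is now safe to read from this program


ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
out = torch.zeros((), dtype=torch.float32, device="cuda")
init_once_kernel[(8, )](ready_flag, out, 42.0)
```

After the handshake, the kernel continues with the persistent tile-scheduling loop: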
```
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    tile_id = start_pid - NUM_SMS
+    ki = -1
+
+    pid_m = 0
+    pid_n = 0
+    offs_am = 0
+    offs_bn = 0
+
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for _ in range(0, k_tiles * tiles_per_SM):
+        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
+        if ki == 0:
+            tile_id += NUM_SMS
+            group_id = tile_id // num_pid_in_group
+            first_pid_m = group_id * GROUP_SIZE_M
+            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+            pid_m = first_pid_m + (tile_id % group_size_m)
+            pid_n = (tile_id % num_pid_in_group) // group_size_m
+
+            offs_am = pid_m * BLOCK_SIZE_M
+            offs_bn = pid_n * BLOCK_SIZE_N
+
+        offs_k = ki * BLOCK_SIZE_K
+
+        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
+        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
+        accumulator = tl.dot(a, b.T, accumulator)
+
+        if ki == k_tiles - 1:
+            c = accumulator.to(dtype)
+
+            tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+
```
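The loop above hands out output tiles round-robin over the launched programs, with each program sweeping all K-tiles of one output tile before moving on. A worked example of the scheduling arithmetic, assuming the fp16 config below and an H100-class GPU with 132 SMs (the SM count is an assumption, not something the diff states):

```python
# Tile-scheduling arithmetic for the tutorial's benchmark shape.
M = N = 8192
BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 256
NUM_SMS = 132  # assumed H100-class GPU

num_pid_m = -(-M // BLOCK_SIZE_M)   # ceil(8192 / 128) = 64
num_pid_n = -(-N // BLOCK_SIZE_N)   # ceil(8192 / 256) = 32
num_tiles = num_pid_m * num_pid_n   # 2048 output tiles

tiles_per_SM = num_tiles // NUM_SMS  # 15 tiles per program, and the first
extra = num_tiles % NUM_SMS          # 68 programs each take one extra tile
print(num_tiles, tiles_per_SM, extra)  # 2048 15 68
```

The host-side wrapper allocates the three 128-byte descriptor buffers and the ready flag, and launches at most NUM_SMS programs: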
```
+def matmul_device_tma_persistent(a, b):
+    # Autotuner does not work with TMA. Use manual config.
+    configs = {
+        torch.float8_e4m3fn: {
+            "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
+            "num_warps": 8
+        }, torch.float16: {
+            "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3,
+            "num_warps": 8
+        }
+    }
+
+    # Check constraints.
+    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
+    assert a.dtype == b.dtype, "Incompatible dtypes"
+
+    M, K = a.shape
+    N, K = b.shape
+    dtype = a.dtype
+
+    c = torch.zeros((M, N), device=a.device, dtype=dtype)
+    a_desc, b_desc, c_desc = [torch.empty(128, dtype=torch.uint8, device="cuda") for _ in range(3)]
+    ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
+    matmul_kernel_device_tma_persistent[grid](
+        a_desc, b_desc, c_desc,  #
+        a, b, c,  #
+        ready_flag,  #
+        M, N, K,  #
+        BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],  #
+        BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],  #
+        BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"],  #
+        GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"],  #
+        NUM_SMS=NUM_SMS,  #
+        num_stages=configs[dtype]["num_stages"],  #
+        num_warps=configs[dtype]["num_warps"],  #
+    )
+    return c
+
+
 def cublas_matmul(a, b):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
```
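For reference, the new wrapper is called like the existing TMA-persistent one. A minimal usage sketch, assuming a Hopper-class GPU (TMA support) and the tutorial's definitions above in scope; `b` is stored row-major as (N, K), i.e. already transposed, as the asserts require:

```python
import torch

M, N, K = 8192, 8192, 512
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)  # B transposed

c = matmul_device_tma_persistent(a, b)
# Same loose tolerance the tutorial's validate() uses.
torch.testing.assert_close(c, a @ b.T, atol=1.0, rtol=0.0)
```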
```
@@ -414,6 +535,9 @@ def bench(K, dtype, reps=10):
         for _ in range(reps):
             matmul_tma_persistent(a, b)
         time.sleep(0.01)
+        for _ in range(reps):
+            matmul_device_tma_persistent(a, b)
+        time.sleep(0.01)
 
     proton.deactivate(0)
 
```
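The utilization figures quoted in the commit message come from this proton instrumentation. A condensed sketch of how the tutorial drives the profiler (the `proton` calls mirror those used elsewhere in the file; the session name is illustrative):

```python
import triton.profiler as proton

proton.start("matmul", hook="triton")   # illustrative session name
proton.activate(0)
for _ in range(10):                     # reps=10, as in bench()
    matmul_device_tma_persistent(a, b)
proton.deactivate(0)
proton.finalize()                       # writes the profile for later viewing
```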
```
@@ -428,6 +552,7 @@ def validate(M, N, K, dtype):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
+    device_tma_persistent_result = matmul_device_tma_persistent(a, b) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -440,14 +565,20 @@ def validate(M, N, K, dtype):
     if tma_persistent_result is not None:
         naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16),
                                                          tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
+    if device_tma_persistent_result is not None:
+        naive_vs_device_tma_persistent = "✅" if torch.allclose(cublas_result.to(
+            torch.float16), device_tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
     print(f"M={M}, N={N}, K={K} verification naive vs: ", end="")
     if torch_result is not None:
         print(f"torch: {naive_vs_torch} ", end="")
     if cublas_result is not None:
         print(f"cublas: {naive_vs_cublas} ", end="")
     print(f"persistent: {naive_vs_persistent} ", end="")
     if tma_persistent_result is not None:
-        print(f"TMA persistent: {naive_vs_tma_persistent}")
+        print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
+    if device_tma_persistent_result is not None:
+        print(f"Device TMA persistent: {naive_vs_device_tma_persistent} ", end="")
+    print()
 
 
 if __name__ == "__main__":
```
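With every variant passing, the validation line now prints along these lines (illustrative output, assuming torch, cuBLAS, and TMA are all available):

```
M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Device TMA persistent: ✅
```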
