This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 017e858

update benchmark

1 parent ee5d4f9 commit 017e858

File tree

1 file changed: +14 additions, -8 deletions


benchmarks/bench_padding.py

Lines changed: 14 additions & 8 deletions
```diff
@@ -47,24 +47,30 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     A_fp8 = A.to(fp8_dtype)
     B_fp8 = B.to(fp8_dtype).t() # view
 
-    A_pad = pad_tensor_for_matmul(A_fp8) # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, both=True).contiguous().t() # mem copy
+    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
-    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
+    A_pad = pad_tensor_for_matmul(A_fp8, dims=1) # mem copy
+    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t() # mem copy
+
+    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
         : A.shape[0], : B.shape[1]
     ]
 
 
 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
-    A_pad = pad_tensor_for_matmul(A) # mem copy
-    B_pad = pad_tensor_for_matmul(B, both=True) # mem copy
+    A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=[0, 1]) # mem copy
+
+    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
     A_pad = A_pad.to(fp8_dtype) # mem copy
     B_pad = B_pad.to(fp8_dtype) # mem copy
 
     B_pad = B_pad.t().contiguous().t() # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
+    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
         : A.shape[0], : B.shape[1]
     ]
```

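This hunk replaces the old `both=True` flag of `pad_tensor_for_matmul` with an explicit `dims` argument. For context, a minimal sketch of what such a helper does, assuming it zero-pads the chosen dims of a 2D tensor up to the next multiple of 16 (the alignment `torch._scaled_mm` expects for fp8 operands); the function below is an illustrative stand-in, not the library's implementation:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple_of_16(t: torch.Tensor, dims) -> torch.Tensor:
    """Hypothetical stand-in for pad_tensor_for_matmul: zero-pads the
    requested dims of a 2D tensor up to the next multiple of 16."""
    if isinstance(dims, int):
        dims = [dims]
    rows, cols = t.shape
    pad_rows = (-rows) % 16 if 0 in dims else 0
    pad_cols = (-cols) % 16 if 1 in dims else 0
    # F.pad consumes pads innermost-dim first: (left, right, top, bottom) for 2D.
    return F.pad(t, (0, pad_cols, 0, pad_rows))

x = torch.randn(60, 255)
print(pad_to_multiple_of_16(x, dims=1).shape)       # torch.Size([60, 256])
print(pad_to_multiple_of_16(x, dims=[0, 1]).shape)  # torch.Size([64, 256])
```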
```diff
@@ -86,8 +92,8 @@ def __iter__(self):
 
 
 def gen_configs():
-    shapes = [(8192, 2500, 5000), (4096, 10, 4096)]
-    output_dtype = torch.float32
+    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
```

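Taken together, the two hunks migrate the benchmark to the `torch._scaled_mm` signature that takes `scale_a`/`scale_b` as positional arguments and returns the output tensor directly (hence the dropped `[0]` indexing). A minimal end-to-end sketch of that call pattern, assuming a CUDA device with fp8 support (e.g. H100) and unit scales as in the diff; `torch._scaled_mm` is a private API whose signature has varied across PyTorch releases, so this mirrors only the form used in this commit:

```python
import torch

# Second benchmark shape from gen_configs; K=255 is not 16-aligned, so it gets padded.
M, K, N = 64, 255, 4096

A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# Zero-pad the K dim of both operands to the next multiple of 16, then cast to fp8.
A_fp8 = torch.nn.functional.pad(A, (0, (-K) % 16)).to(torch.float8_e4m3fn)
B_fp8 = torch.nn.functional.pad(B, (0, 0, 0, (-K) % 16)).to(torch.float8_e4m3fn)

# Unit scales, constructed exactly as in the diff.
scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

# Scales are passed positionally and the op returns the output tensor directly,
# not the (output, amax) tuple returned by older PyTorch versions.
# mat2 must be column-major, hence the .t().contiguous().t() dance.
out = torch._scaled_mm(
    A_fp8, B_fp8.t().contiguous().t(), scale_a, scale_b, out_dtype=torch.bfloat16
)[:M, :N]
```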