Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 55af19f

Browse files
committed
bigger sweep
1 parent 017e858 commit 55af19f

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

benchmarks/bench_padding.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import fire
55

66
import torch
7-
import torch.utils.benchmark as benchmark
87
from float8_experimental.float8_utils import pad_tensor_for_matmul
98
from tabulate import tabulate
9+
from torch._inductor.utils import do_bench_using_profiling
10+
from tqdm import tqdm
1011

1112
# estimating TOPs for matmuls in fp32, fp16, fp8
1213
# assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -26,14 +27,9 @@
2627

2728

2829
def benchmark_fn_in_usec(f, *args, **kwargs):
29-
# Manual warmup
30-
for _ in range(4):
31-
f(*args, **kwargs)
32-
t0 = benchmark.Timer(
33-
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
34-
)
35-
measurement = t0.blocked_autorange()
36-
return measurement.mean * 1e6
30+
no_args = lambda: f(*args, **kwargs)
31+
time = do_bench_using_profiling(no_args)
32+
return time * 1e3
3733

3834

3935
def get_tops_info(tops, time, peak_tops):
@@ -51,16 +47,17 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
5147
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
5248

5349
A_pad = pad_tensor_for_matmul(A_fp8, dims=1) # mem copy
54-
B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t() # mem copy
50+
B_pad = pad_tensor_for_matmul(B_fp8, dims=0).contiguous().t() # mem copy
5551

56-
return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
57-
: A.shape[0], : B.shape[1]
58-
]
52+
return torch._scaled_mm(
53+
A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
54+
)
5955

6056

6157
def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
58+
# We are only going to test the shape preserving
6259
A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy
63-
B_pad = pad_tensor_for_matmul(B, dims=[0, 1]) # mem copy
60+
B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy
6461

6562
scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
6663
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -70,9 +67,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
7067

7168
B_pad = B_pad.t().contiguous().t() # mem copy
7269

73-
return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
74-
: A.shape[0], : B.shape[1]
75-
]
70+
return torch._scaled_mm(
71+
A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
72+
)
7673

7774

7875
def do_hp_matmul(A, B):
@@ -92,7 +89,18 @@ def __iter__(self):
9289

9390

9491
def gen_configs():
95-
shapes = [(8192, 2500, 5000), (64, 255, 4096)]
92+
shapes = shapes = [
93+
(8193, 2501, 5008),
94+
(65, 253, 4096),
95+
(1023, 1029, 2512),
96+
(4095, 511, 10000),
97+
(2047, 3073, 8192),
98+
(511, 769, 7504),
99+
(127, 4097, 12288),
100+
(32769, 15, 15024),
101+
(9217, 8191, 20480),
102+
(16385, 1025, 25008),
103+
]
96104
output_dtype = torch.bfloat16
97105
fp8_dtype = torch.float8_e4m3fn
98106
return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
@@ -112,7 +120,7 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
112120
"Ref % Peak",
113121
"FP8 % Peak",
114122
]
115-
for experiment in experiments:
123+
for experiment in tqdm(experiments):
116124
M, K, N, output_dtype, fp8_dtype = experiment
117125
tops = 2 * M * N * K
118126

0 commit comments

Comments
 (0)