This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 29e48ac
Add utilities for padding and add to bench_padding.py
1 parent 1e9add3 commit 29e48ac

2 files changed: +197, -0 lines changed

benchmarks/bench_padding.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
from dataclasses import dataclass
from typing import Optional

import fire

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_fn_in_usec(f, *args, **kwargs):
    # Manual warmup
    for _ in range(4):
        f(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    measurement = t0.blocked_autorange()
    return measurement.mean * 1e6


def get_tops_info(tops, time, peak_tops):
    time_sec = time / 1e6
    tops_sec = float(tops) / time_sec
    pct_top_peak = tops_sec / peak_tops
    return tops_sec, pct_top_peak


def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
    A_fp8 = A.to(fp8_dtype)
    B_fp8 = B.to(fp8_dtype).t()  # view

    A_pad = pad_tensor_for_matmul(A_fp8)  # mem copy
    B_pad = pad_tensor_for_matmul(B_fp8, both=True).contiguous().t()  # mem copy

    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
        : A.shape[0], : B.shape[1]
    ]


def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
    A_pad = pad_tensor_for_matmul(A)  # mem copy
    B_pad = pad_tensor_for_matmul(B, both=True)  # mem copy

    A_pad = A_pad.to(fp8_dtype)  # mem copy
    B_pad = B_pad.to(fp8_dtype)  # mem copy

    B_pad = B_pad.t().contiguous().t()  # mem copy

    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
        : A.shape[0], : B.shape[1]
    ]


def do_hp_matmul(A, B):
    return torch.matmul(A, B)


@dataclass
class Experiment_config:
    M: int
    K: int
    N: int
    output_dtype: torch.dtype
    fp8_dtype: torch.dtype

    def __iter__(self):
        return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))


def gen_configs():
    shapes = [(8192, 2500, 5000), (4096, 10, 4096)]
    output_dtype = torch.float32
    fp8_dtype = torch.float8_e4m3fn
    return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]


@torch.no_grad()
def run(compile: bool = False, n_limit: Optional[int] = None):
    device = "cuda"
    experiments = gen_configs()
    results = []
    tops_table = []
    tops_headers = [
        "Shape",
        "Ref Dtype",
        "Ref Tops",
        "FP8 Tops",
        "Ref % Peak",
        "FP8 % Peak",
    ]
    for experiment in experiments:
        M, K, N, output_dtype, fp8_dtype = experiment
        tops = 2 * M * N * K

        A_base = torch.rand(M, K, device=device, dtype=output_dtype)
        B_base = torch.rand(K, N, device=device, dtype=output_dtype)

        hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
        fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul

        ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
        fp8_time = benchmark_fn_in_usec(
            fp8_func, A_base, B_base, fp8_dtype, output_dtype
        )

        ref_tops_sec, ref_pct_top_peak = get_tops_info(
            tops, ref_time, dtype_to_peak_tops[output_dtype]
        )
        fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
            tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
        )
        tops_table.append(
            [
                f"({M}x{K}x{N})",
                f"{output_dtype}",
                f"{ref_tops_sec:.2E}",
                f"{fp8_tops_sec:.2E}",
                f"{ref_pct_top_peak:.3f}",
                f"{fp8_pct_top_peak:.3f}",
            ]
        )
        results.append(
            [(M, K, N), output_dtype, ref_time, fp8_time, ref_time / fp8_time]
        )

    print("TOPs".center(80, "*"))
    print(tabulate(tops_table, headers=tops_headers))
    print("Speed Results".center(80, "*"))
    headers = ["Shape", "Ref Dtype", "Ref Time", "FP8 Time", "Speedup"]
    print(tabulate(results, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
    fire.Fire(run)
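
Since run is wired through fire.Fire, the benchmark can be driven from the command line or imported and called directly. Below is a minimal sketch, assuming the script is on the import path and an fp8-capable GPU (e.g. H100) is available; the invocation itself is not part of the commit:

# Hypothetical driver; assumes benchmarks/bench_padding.py is importable and a
# CUDA device with fp8 support is present. Roughly equivalent to:
#   python benchmarks/bench_padding.py --compile=True
from bench_padding import run

run(compile=False)  # eager: torch.matmul baseline vs. do_fp8_matmul (cast, then pad)
run(compile=True)   # compiled: baseline vs. do_fp8_pad_first_matmul (pad, then cast)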

float8_experimental/float8_utils.py

Lines changed: 45 additions & 0 deletions
@@ -172,3 +172,48 @@ def fp8_tensor_statistics(
def is_row_major(stride):
    assert len(stride) == 2, "is_row_major only supports 2D tensors"
    return stride[0] > stride[1] and stride[1] == 1


def get_min_alignment(size: int, alignment_value: int):
    """
    Returns the smallest multiple of alignment_value that is greater than or equal to size.

    Args:
        size: The size of the data to be aligned.
        alignment_value: The alignment value to be used.

    Returns:
        int: The smallest multiple of alignment_value that is greater than or equal to size.
    """
    if size % alignment_value == 0:
        return size
    return (1 + (size // alignment_value)) * alignment_value


def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
    """
    Pads a 2D tensor with zeros so that its dimensions are multiples of 16, which is required for H100s.

    Args:
        tensor: The tensor to pad.
        both: Whether to pad both dimensions or just the second dimension.

    Returns:
        torch.Tensor: The padded tensor.
    """
    assert tensor.dim() == 2
    dim1, dim2 = tensor.shape

    # Calculate aligned dimensions
    dim2_aligned = get_min_alignment(dim2, 16)
    dim1_aligned = get_min_alignment(dim1, 16) if both else dim1

    # Check if padding is needed for either dimension
    if dim1 == dim1_aligned and dim2 == dim2_aligned:
        return tensor

    # Calculate padding values for both dimensions
    pad_dim1 = dim1_aligned - dim1
    pad_dim2 = dim2_aligned - dim2

    return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
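
For illustration, here is a small worked example of how these helpers behave on the first benchmark shape from bench_padding.py (M=8192, K=2500, N=5000). This sketch is not part of the commit; it uses CPU tensors just to check the resulting shapes:

import torch

from float8_experimental.float8_utils import get_min_alignment, pad_tensor_for_matmul

# K = 2500 is not a multiple of 16, so it rounds up to the next multiple.
assert get_min_alignment(2500, 16) == 2512
# Already-aligned sizes are returned unchanged.
assert get_min_alignment(8192, 16) == 8192

A = torch.randn(8192, 2500)
B = torch.randn(2500, 5000)

# By default only the second (inner) dimension is padded ...
assert pad_tensor_for_matmul(A).shape == (8192, 2512)
# ... while both=True pads both dimensions, as done for the B operand above.
assert pad_tensor_for_matmul(B, both=True).shape == (2512, 5008)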
