This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 1e9add3

vkuzo authored and facebook-github-bot committed
QOL improvements to linear benchmarking script (#278)
Summary:
Pull Request resolved: #278

1. add more command line filters
2. add dynamic scaling
3. remove float16, since it's low-pri and this cuts down benchmark time by 50%

Reviewed By: drisspg

Differential Revision: D58396927

fbshipit-source-id: 298cb3c48418d4b9dd1529fe38cf2229ab5618b7
1 parent 5d293a7 commit 1e9add3


benchmarks/bench_linear_float8.py

Lines changed: 70 additions & 19 deletions
@@ -14,8 +14,13 @@
 
 import torch
 import torch.utils.benchmark as benchmark
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
+from float8_experimental.float8_linear_utils import (
+    get_float8_linear,
+    LinearType,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.float8_tensor import ScaledMMConfig
 from tqdm import tqdm
 
@@ -35,6 +40,13 @@
     torch.float8_e5m2: h100_peak_tops_float8_tc,
 }
 
+# prevent splitting columns when printing a data frame
+pd.set_option("display.expand_frame_repr", False)
+# print the entire data frame
+pd_print_full_ctx = pd.option_context(
+    "display.max_rows", None, "display.max_columns", None
+)
+
 
 def benchmark_torch_function_in_microseconds(
     func: Callable,
@@ -57,6 +69,7 @@ class Experiment:
     dtype: torch.dtype
     compiled: bool
     use_fast_accum: bool
+    linear_type: str
 
     # 3 Times since we are calculating forward backward
     @property
@@ -79,9 +92,12 @@ def float8_pct_top_peak(self):
 
 
 def main(
-    sweep_path: Path,
-    compile: bool,
+    sweep_path: Optional[Path] = None,
+    compile: bool = False,
     n_limit: Optional[int] = None,
+    fast_accum_filter: Optional[bool] = None,
+    shape_name_filter: Optional[str] = None,
+    linear_type_filter: Optional[str] = None,
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
@@ -95,20 +111,33 @@ def main(
         "ffn.w2": (3584, 8192),
     }
     input_bias = False
-    ref_dtypes = [torch.bfloat16, torch.float16]
-    use_fast_accum = [True, False]
+    if fast_accum_filter is not None:
+        use_fast_accum = [fast_accum_filter]
+    else:
+        use_fast_accum = [True, False]
+    if linear_type_filter is not None:
+        linear_types = [linear_type_filter]
+    else:
+        linear_types = ["delayed", "dynamic"]
+    if shape_name_filter is not None:
+        k = shape_name_filter
+        name_to_shapes_70b = {k: name_to_shapes_70b[k]}
     experiment_list: List[Experiment] = []
-    for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
-        tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
+    dtype = torch.bfloat16
+    for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
+        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
     ):
         if n_limit is not None and idx >= n_limit:
             break
         linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
             device=device, dtype=dtype
         )
+        linear_type_enum = (
+            LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
+        )
 
-        linear_float8 = Float8Linear.from_float(
-            copy.deepcopy(linear_ref), emulate=False
+        linear_float8 = get_float8_linear(
+            linear_type_enum, copy.deepcopy(linear_ref), emulate=False
         )
         if fast_accum:
             linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -120,9 +149,16 @@ def main(
         input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
-        def float8_forw_backward():
-            sync_float8_amax_and_scale_history(linear_float8)
-            linear_float8(input_tensor).sum().backward()
+        if linear_type_enum == LinearType.DELAYED:
+
+            def float8_forw_backward():
+                sync_float8_amax_and_scale_history(linear_float8)
+                linear_float8(input_tensor).sum().backward()
+
+        else:
+
+            def float8_forw_backward():
+                linear_float8(input_tensor).sum().backward()
 
         def n_times(n, fn, *args, **kwargs):
             def wrapper(*args, **kwargs):
@@ -162,6 +198,7 @@ def wrapper(*args, **kwargs):
             dtype,
             compile,
             use_fast_accum=fast_accum,
+            linear_type=linear_type,
         )
         print(experiment)
         print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -173,6 +210,7 @@ def wrapper(*args, **kwargs):
         "M",
         "K",
         "N",
+        "linear_type",
         "ref_dtype",
         "compiled",
         "use_fast_accum",
@@ -191,6 +229,7 @@ def wrapper(*args, **kwargs):
                 experiment.shape[0],
                 experiment.shape[1],
                 experiment.shape[2],
+                experiment.linear_type,
                 experiment.dtype,
                 experiment.compiled,
                 experiment.use_fast_accum,
@@ -219,28 +258,40 @@ def wrapper(*args, **kwargs):
         [
             "name",
             "shape",
-            "ref_dtype",
+            "linear_type",
             "compiled",
             "use_fast_accum",
             "ref_time_sec",
             "pt_fp8_time_sec",
             "pt_fp8_speedup",
         ]
     ]
-    print(data_pd_simple)
+    with pd_print_full_ctx:
+        print(data_pd_simple)
 
-    sweep_path = sweep_path.with_suffix(".csv")
-    data_pd.to_csv(sweep_path)
+    if sweep_path is not None:
+        sweep_path = sweep_path.with_suffix(".csv")
+        data_pd.to_csv(sweep_path)
 
 
 def invoke_main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument("-o", "--output_path", type=str, required=True)
+    parser.add_argument("-o", "--output_path", type=str, required=False)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("-n", "--n_limit", type=int, required=False)
+    parser.add_argument("--fast_accum_filter", type=bool, required=False)
+    parser.add_argument("--shape_name_filter", type=str, required=False)
+    parser.add_argument("--linear_type_filter", type=str, required=False)
     args = parser.parse_args()
-    output_path = Path(args.output_path)
-    main(output_path, args.compile, args.n_limit)
+    output_path = Path(args.output_path) if args.output_path is not None else None
+    main(
+        output_path,
+        args.compile,
+        args.n_limit,
+        args.fast_accum_filter,
+        args.shape_name_filter,
+        args.linear_type_filter,
+    )
 
 
 if __name__ == "__main__":
