This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 48bd393

[wip] static scaling support for training
Summary:

In certain cases, activations and gradients can have a bounded range. For example, consider sigmoid -> fc -> ln -> sigmoid:

1. The range of sigmoid in the forward pass is bounded, so we can scale statically if we are OK with a slight accuracy drop in the case that the observed values do not reach the theoretical bound.
2. The range of the derivative of sigmoid is bounded (https://math.stackexchange.com/questions/78575/derivative-of-sigmoid-function-sigma-x-frac11e-x).
3. The derivative of LN (https://liorsinai.github.io/mathematics/2022/05/18/layernorm.html) depends on the incoming gradient and the trainable LN parameters, so we can derive a bound from the incoming bound and the max of the LN parameters.

This PR adds static scaling as an option for x, w, and dL_dY, plus a quick benchmark to verify that performance is as we expect. TODO: add numerics testing.

Test Plan:

```
// baseline
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid
...
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.160     0.098       0.613       1.632
1_f8_overhead  0.000     0.100         inf       0.000
2_other        0.147     0.121       0.823       1.215
All            0.307     0.319       1.040       0.962

// static scaling for x (easier to justify numerics given a bounded activation such as sigmoid)
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid --scaling_type_x static
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.665     0.362       0.545       1.834
1_f8_overhead  0.000     0.269         inf       0.000
2_other        0.396     0.273       0.689       1.452
All            1.061     0.904       0.853       1.173

// static scaling for x and dL_dY (handwaving for now, the actual code would
// need to read the LN params to get the max)
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid --scaling_type_x static --scaling_type_dL_dY static
...
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.665     0.365       0.549       1.822
1_f8_overhead  0.000     0.242         inf       0.000
2_other        0.395     0.273       0.690       1.448
All            1.060     0.879       0.830       1.205
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 538c24e
Pull Request resolved: #306
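
As a companion to the Summary above (not part of this commit): a minimal sketch of how a static scale could be derived from a theoretical bound. The helper name `static_scale_from_bound`, the choice of `torch.float8_e4m3fn`, and the `scale = float8_max / amax` convention are assumptions for illustration only.

```
import torch

def static_scale_from_bound(
    theoretical_amax: float, float8_dtype=torch.float8_e4m3fn
) -> torch.Tensor:
    # map the theoretical max-abs value onto the largest representable
    # float8 value; torch.finfo(torch.float8_e4m3fn).max is 448.0
    float8_max = torch.finfo(float8_dtype).max
    return torch.tensor(float8_max / theoretical_amax)

# forward: sigmoid output lies in (0, 1), so a bound of 1.0 covers x
static_scale_x = static_scale_from_bound(1.0)

# backward: the derivative of sigmoid is bounded by 0.25; per the Summary,
# the dL_dY bound would additionally fold in the incoming gradient bound
# and the max of the trainable LayerNorm parameters
static_scale_sigmoid_grad = static_scale_from_bound(0.25)

print(static_scale_x, static_scale_sigmoid_grad)
```

Because the bound is known ahead of time, the scale is a constant and does not need to be recomputed from observed amax values during training.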
1 parent 8e9623a commit 48bd393

3 files changed: 230 additions & 87 deletions


benchmarks/profile_linear_float8.py

Lines changed: 118 additions & 76 deletions
```
@@ -149,6 +149,22 @@ def forward(self, h):
         return x
 
 
+class SigmoidLinearLNSigmoid(nn.Module):
+    def __init__(self, d1, d2):
+        super().__init__()
+        self.sigmoid1 = nn.Sigmoid()
+        self.fc = nn.Linear(d1, d2)
+        self.ln = nn.LayerNorm(d2)
+        self.sigmoid2 = nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.sigmoid1(x)
+        x = self.fc(x)
+        x = self.ln(x)
+        x = self.sigmoid2(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -210,7 +226,12 @@ def main(
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
-    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+    assert model_type in (
+        "linear",
+        "ln_linear",
+        "norm_ffn_norm",
+        "sigmoid_linear_ln_sigmoid",
+    ), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
     scaling_type_x = TensorScalingType(scaling_type_x)
@@ -242,6 +263,12 @@ def main(
         input_tensor = torch.randn(
             1, 8192, 4096, device=device, dtype=ref_dtype
         ).requires_grad_()
+    elif model_type == "sigmoid_linear_ln_sigmoid":
+        bsz, d1, d2 = 4096, 4096, 4096
+        m_ref = SigmoidLinearLNSigmoid(d1, d2)
+        input_tensor = torch.randn(
+            bsz, d1, device=device, dtype=ref_dtype, requires_grad=True
+        )
     else:
         M, K, N = 4 * 4096, 8192, 7168
         m_ref = torch.nn.Sequential(
@@ -258,6 +285,15 @@ def main(
         "scaling_type_w": scaling_type_w,
         "scaling_type_dL_dY": scaling_type_dL_dY,
     }
+    if scaling_type_x is TensorScalingType.STATIC:
+        # for now, dummy scale
+        extra_kwargs["static_scale_x"] = 1.0
+    if scaling_type_w is TensorScalingType.STATIC:
+        # for now, dummy scale
+        extra_kwargs["static_scale_w"] = 1.0
+    if scaling_type_dL_dY is TensorScalingType.STATIC:
+        # for now, dummy scale
+        extra_kwargs["static_scale_dL_dY"] = 1.0
 
     m_float8 = copy.deepcopy(m_ref)
     swap_linear_with_float8_linear(m_float8, **extra_kwargs)
@@ -300,85 +336,91 @@ def float8_forw_backward_wrapper(x):
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
     f = io.StringIO()
-    with redirect_stdout(f):
-        # warm up
-        for _ in range(1):
+    try:
+        with redirect_stdout(f):
+            # warm up
+            for _ in range(1):
+                if dtype_filter != "float8":
+                    ref_forw_backward(input_tensor)
+                if dtype_filter != "bfloat16":
+                    float8_forw_backward_wrapper(input_tensor)
+
+            profile_iters = 5
+            ref_times, float8_times = None, None
+            data = []
+
             if dtype_filter != "float8":
-                ref_forw_backward(input_tensor)
-            if dtype_filter != "bfloat16":
-                float8_forw_backward_wrapper(input_tensor)
-
-        profile_iters = 5
-        ref_times, float8_times = None, None
-        data = []
-
-        if dtype_filter != "float8":
-            # Profile Reference Model
-            print("profiling ref")
-            ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
-            ref_path = profile_path_prefix + ref_suffix
-            profile_config = ProfileConfig(
-                ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
-            )
-            p = profile_function(profile_config, ref_forw_backward, input_tensor)
-            print(f"saved {ref_path}")
-            ref_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
-            for k, v in ref_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "0_ref",
-                        k,
-                        kernel_name_to_category(k),
-                        v_ms,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
+                # Profile Reference Model
+                print("profiling ref")
+                ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+                ref_path = profile_path_prefix + ref_suffix
+                profile_config = ProfileConfig(
+                    ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
                 )
+                p = profile_function(profile_config, ref_forw_backward, input_tensor)
+                print(f"saved {ref_path}")
+                ref_times = profiler_output_to_time_by_kernel_name(p)
+                total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+                for k, v in ref_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "0_ref",
+                            k,
+                            kernel_name_to_category(k),
+                            v_ms,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
 
-        if dtype_filter != "bfloat16":
-            # Profile Float8 Model
-            print("profiling float8")
-            float8_suffix = (
-                f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
-            )
-            float8_path = profile_path_prefix + float8_suffix
-            profile_config = ProfileConfig(
-                float8_path,
-                float8_suffix,
-                iters=profile_iters,
-                warmup_iters=2,
-                sync=True,
-            )
-            p = profile_function(
-                profile_config, float8_forw_backward_wrapper, input_tensor
-            )
-            print(f"saved {float8_path}")
-            float8_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
-            for k, v in float8_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "1_float8",
-                        k,
-                        kernel_name_to_category(k),
-                        v / 1e3 / profile_iters,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
+            if dtype_filter != "bfloat16":
+                # Profile Float8 Model
+                print("profiling float8")
+                float8_suffix = (
+                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
                 )
-
-            # get the time spent per user annotation
-            sync_time_us = profiler_output_to_gpu_time_for_key(
-                p, "scale_amax_and_scales"
-            )
-            sync_time_ms = sync_time_us / profile_iters / 1e3
-            print(f"Sync time ms: {sync_time_ms}")
-
-    # print the redirected stdout back to regular stdout
-    print(f.getvalue())
+                float8_path = profile_path_prefix + float8_suffix
+                profile_config = ProfileConfig(
+                    float8_path,
+                    float8_suffix,
+                    iters=profile_iters,
+                    warmup_iters=2,
+                    sync=True,
+                )
+                p = profile_function(
+                    profile_config, float8_forw_backward_wrapper, input_tensor
+                )
+                print(f"saved {float8_path}")
+                float8_times = profiler_output_to_time_by_kernel_name(p)
+                total_time_ms = (
+                    sum(v for v in float8_times.values()) / 1e3 / profile_iters
+                )
+                for k, v in float8_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "1_float8",
+                            k,
+                            kernel_name_to_category(k),
+                            v / 1e3 / profile_iters,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
+
+                # get the time spent per user annotation
+                sync_time_us = profiler_output_to_gpu_time_for_key(
+                    p, "scale_amax_and_scales"
+                )
+                sync_time_ms = sync_time_us / profile_iters / 1e3
+                print(f"Sync time ms: {sync_time_ms}")
+
+    finally:
+        # print the redirected stdout back to regular stdout
+        # the finally clause is to help print output in the presence of exceptions,
+        # to aid local debugging
+        print(f.getvalue())
 
     # populate the triton kernel bandwidth
     for line in f.getvalue().split("\n"):
```

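The Summary leaves numerics testing as a TODO. Below is a minimal sketch (not part of this commit) of what such a check could look like for the bounded-activation case, comparing dynamic scaling from the observed amax against static scaling at the theoretical sigmoid bound. It uses plain PyTorch float8 casts rather than this repository's APIs; the SQNR metric, the quantize/dequantize round trip, and the tensor shape are illustrative assumptions.

```
import torch

def quant_dequant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scale into the float8 range, cast, then undo the scaling
    x_f8 = (x * scale).to(torch.float8_e4m3fn)
    return x_f8.to(torch.float32) / scale

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in decibels
    return 10 * torch.log10(ref.pow(2).mean() / (ref - approx).pow(2).mean())

x = torch.sigmoid(torch.randn(4096, 4096))
f8_max = torch.finfo(torch.float8_e4m3fn).max

dynamic_scale = f8_max / x.abs().max()     # from the observed amax
static_scale = torch.tensor(f8_max / 1.0)  # from the theoretical sigmoid bound

print("dynamic SQNR (dB):", sqnr_db(x, quant_dequant(x, dynamic_scale)).item())
print("static  SQNR (dB):", sqnr_db(x, quant_dequant(x, static_scale)).item())
```

This mirrors the accuracy caveat in the Summary: static scaling gives up some precision whenever the observed values fall short of the theoretical bound.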