This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ffd4bcf

[wip] static scaling support for training
Summary:

In certain cases, activations and gradients have a bounded range. For example, consider `sigmoid -> fc -> ln -> sigmoid`:

1. The range of sigmoid in the forward pass is bounded, so we can scale statically if we accept a slight accuracy drop in the case that the observed values do not reach the theoretical bound.
2. The range of the derivative of sigmoid is bounded (https://math.stackexchange.com/questions/78575/derivative-of-sigmoid-function-sigma-x-frac11e-x).
3. The derivative of LN (https://liorsinai.github.io/mathematics/2022/05/18/layernorm.html) depends on the incoming gradient and the trainable LN parameters, so we can derive a bound from the incoming gradient's bound and the max of the LN parameters.

This PR adds static scaling as an option for `x`, `w`, and `dL_dY`, plus a quick benchmark to verify that performance is as we expect. TODO: add numerics testing.

Test Plan:

```
// baseline
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid
...
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.160     0.098       0.613       1.632
1_f8_overhead  0.000     0.100         inf       0.000
2_other        0.147     0.121       0.823       1.215
All            0.307     0.319       1.040       0.962

// static scaling for x (easier to justify numerics given a bounded activation such as sigmoid)
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid --scaling_type_x static
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.665     0.362       0.545       1.834
1_f8_overhead  0.000     0.269         inf       0.000
2_other        0.396     0.273       0.689       1.452
All            1.061     0.904       0.853       1.173

// static scaling for x and dL_dY (handwaving for now, the actual code would
// need to read the LN params to get the max)
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid --scaling_type_x static --scaling_type_dL_dY static
...
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.665     0.365       0.549       1.822
1_f8_overhead  0.000     0.242         inf       0.000
2_other        0.395     0.273       0.690       1.448
All            1.060     0.879       0.830       1.205
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: b8fd4e9
Pull Request resolved: #306
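To make the bound argument above concrete, here is a minimal sketch (not part of this commit) of how a theoretical bound could be turned into a static scale, assuming the usual `scale = fp8_max / amax` convention used by the dynamic-scaling path; `g_max` and `ln_weight_max` are hypothetical placeholders for the LN-derived bound described in item 3, and the benchmark below simply uses a dummy scale instead:

```python
import torch

# fp8 dtype ranges exposed by PyTorch
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0, used for activations
E5M2_MAX = torch.finfo(torch.float8_e5m2).max    # 57344.0, typically used for gradients

# forward: sigmoid(x) <= 1.0, so the activation feeding the linear has a known amax
sigmoid_bound = 1.0
static_scale_x = torch.tensor(E4M3_MAX / sigmoid_bound)

# backward: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) <= 0.25, so an incoming
# gradient bounded by g_max stays bounded by 0.25 * g_max after the sigmoid.
# g_max and ln_weight_max are illustrative placeholders only; real code would
# derive them from the incoming bound and the max of the LN parameters.
g_max, ln_weight_max = 1.0, 1.0
static_scale_dL_dY = torch.tensor(E5M2_MAX / (0.25 * g_max * ln_weight_max))
```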
1 parent 3b69b15 commit ffd4bcf

5 files changed, +173 −88 lines changed

benchmarks/profile_linear_float8.py

Lines changed: 108 additions & 76 deletions

@@ -149,6 +149,18 @@ def forward(self, h):
         return x
 
 
+class SigmoidLinear(nn.Module):
+    def __init__(self, d1, d2):
+        super().__init__()
+        self.sigmoid1 = nn.Sigmoid()
+        self.fc = nn.Linear(d1, d2)
+
+    def forward(self, x):
+        x = self.sigmoid1(x)
+        x = self.fc(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -210,7 +222,12 @@ def main(
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
-    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+    assert model_type in (
+        "linear",
+        "ln_linear",
+        "norm_ffn_norm",
+        "sigmoid_linear",
+    ), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
     scaling_type_x = TensorScalingType(scaling_type_x)
@@ -242,6 +259,12 @@ def main(
         input_tensor = torch.randn(
             1, 8192, 4096, device=device, dtype=ref_dtype
         ).requires_grad_()
+    elif model_type == "sigmoid_linear":
+        bsz, d1, d2 = 4096, 4096, 4096
+        m_ref = SigmoidLinear(d1, d2)
+        input_tensor = torch.randn(
+            bsz, d1, device=device, dtype=ref_dtype, requires_grad=True
+        )
     else:
         M, K, N = 4 * 4096, 8192, 7168
         m_ref = torch.nn.Sequential(
@@ -258,6 +281,9 @@ def main(
         "scaling_type_w": scaling_type_w,
         "scaling_type_dL_dY": scaling_type_dL_dY,
     }
+    if scaling_type_x is TensorScalingType.STATIC:
+        # for now, dummy scale
+        extra_kwargs["static_scale_x"] = torch.tensor(1.0, device="cuda")
 
     m_float8 = copy.deepcopy(m_ref)
     swap_linear_with_float8_linear(m_float8, **extra_kwargs)
@@ -300,85 +326,91 @@ def float8_forw_backward_wrapper(x):
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
     f = io.StringIO()
-    with redirect_stdout(f):
-        # warm up
-        for _ in range(1):
+    try:
+        with redirect_stdout(f):
+            # warm up
+            for _ in range(1):
+                if dtype_filter != "float8":
+                    ref_forw_backward(input_tensor)
+                if dtype_filter != "bfloat16":
+                    float8_forw_backward_wrapper(input_tensor)
+
+            profile_iters = 5
+            ref_times, float8_times = None, None
+            data = []
+
             if dtype_filter != "float8":
-                ref_forw_backward(input_tensor)
-            if dtype_filter != "bfloat16":
-                float8_forw_backward_wrapper(input_tensor)
-
-        profile_iters = 5
-        ref_times, float8_times = None, None
-        data = []
-
-        if dtype_filter != "float8":
-            # Profile Reference Model
-            print("profiling ref")
-            ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
-            ref_path = profile_path_prefix + ref_suffix
-            profile_config = ProfileConfig(
-                ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
-            )
-            p = profile_function(profile_config, ref_forw_backward, input_tensor)
-            print(f"saved {ref_path}")
-            ref_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
-            for k, v in ref_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "0_ref",
-                        k,
-                        kernel_name_to_category(k),
-                        v_ms,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
+                # Profile Reference Model
+                print("profiling ref")
+                ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+                ref_path = profile_path_prefix + ref_suffix
+                profile_config = ProfileConfig(
+                    ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
                 )
+                p = profile_function(profile_config, ref_forw_backward, input_tensor)
+                print(f"saved {ref_path}")
+                ref_times = profiler_output_to_time_by_kernel_name(p)
+                total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+                for k, v in ref_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "0_ref",
+                            k,
+                            kernel_name_to_category(k),
+                            v_ms,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
 
-        if dtype_filter != "bfloat16":
-            # Profile Float8 Model
-            print("profiling float8")
-            float8_suffix = (
-                f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
-            )
-            float8_path = profile_path_prefix + float8_suffix
-            profile_config = ProfileConfig(
-                float8_path,
-                float8_suffix,
-                iters=profile_iters,
-                warmup_iters=2,
-                sync=True,
-            )
-            p = profile_function(
-                profile_config, float8_forw_backward_wrapper, input_tensor
-            )
-            print(f"saved {float8_path}")
-            float8_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
-            for k, v in float8_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "1_float8",
-                        k,
-                        kernel_name_to_category(k),
-                        v / 1e3 / profile_iters,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
+            if dtype_filter != "bfloat16":
+                # Profile Float8 Model
+                print("profiling float8")
+                float8_suffix = (
+                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
                 )
-
-            # get the time spent per user annotation
-            sync_time_us = profiler_output_to_gpu_time_for_key(
-                p, "scale_amax_and_scales"
-            )
-            sync_time_ms = sync_time_us / profile_iters / 1e3
-            print(f"Sync time ms: {sync_time_ms}")
-
-    # print the redirected stdout back to regular stdout
-    print(f.getvalue())
+                float8_path = profile_path_prefix + float8_suffix
+                profile_config = ProfileConfig(
+                    float8_path,
+                    float8_suffix,
+                    iters=profile_iters,
+                    warmup_iters=2,
+                    sync=True,
+                )
+                p = profile_function(
+                    profile_config, float8_forw_backward_wrapper, input_tensor
+                )
+                print(f"saved {float8_path}")
+                float8_times = profiler_output_to_time_by_kernel_name(p)
+                total_time_ms = (
+                    sum(v for v in float8_times.values()) / 1e3 / profile_iters
+                )
+                for k, v in float8_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "1_float8",
+                            k,
+                            kernel_name_to_category(k),
+                            v / 1e3 / profile_iters,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
+
+                # get the time spent per user annotation
+                sync_time_us = profiler_output_to_gpu_time_for_key(
+                    p, "scale_amax_and_scales"
+                )
+                sync_time_ms = sync_time_us / profile_iters / 1e3
+                print(f"Sync time ms: {sync_time_ms}")
+
+    finally:
+        # print the redirected stdout back to regular stdout
+        # the finally clause is to help print output in the presence of exceptions,
+        # to aid local debugging
+        print(f.getvalue())
 
     # populate the triton kernel bandwidth
     for line in f.getvalue().split("\n"):
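The benchmark above passes a dummy static scale of 1.0. As a hedged sketch (not part of this commit), a non-dummy scale for the `sigmoid_linear` model could instead be derived from the sigmoid output bound, assuming the `scale = fp8_max / amax` convention:

```python
# hypothetical refinement of the dummy scale above; the input to the swapped
# linear is a sigmoid output, so its theoretical amax is 1.0
if scaling_type_x is TensorScalingType.STATIC:
    sigmoid_bound = 1.0
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    extra_kwargs["static_scale_x"] = torch.tensor(
        e4m3_max / sigmoid_bound, device="cuda"
    )
```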

float8_experimental/float8_linear.py

Lines changed: 33 additions & 7 deletions

@@ -135,13 +135,16 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
 class TensorScalingType(enum.Enum):
     DELAYED = "delayed"
     DYNAMIC = "dynamic"
+    STATIC = "static"  # note: only supported for `x`
 
     def short_str(self):
         if self is TensorScalingType.DELAYED:
             return "del"
-        else:
-            assert self is TensorScalingType.DYNAMIC
+        elif self is TensorScalingType.DYNAMIC:
             return "dyn"
+        else:
+            assert self is TensorScalingType.STATIC
+            return "sta"
 
 
 class Float8Linear(torch.nn.Linear):
@@ -154,9 +157,10 @@ def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
         * `delayed_scaling_recipe`: configuration for delayed scaling
-        * `scaling_type_x`: delayed vs dynamic scaling for `x`
-        * `scaling_type_w`: delayed vs dynamic scaling for `w`
-        * `scaling_type_dL_dY`: delayed vs dynamic scaling for `dL_dY`
+        * `scaling_type_x`: dynamic/delayed/static scaling for `x`
+        * `scaling_type_w`: dynamic/delayed scaling for `w`
+        * `scaling_type_dL_dY`: dynamic/delayed scaling for `dL_dY`
+        * `static_scale_x`: static scale for `x`, requires static scaling
         """
 
         delayed_scaling_recipe = kwargs.pop(
@@ -168,19 +172,32 @@ def __init__(self, *args, **kwargs):
         scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
         scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
         scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
+        static_scale_x = kwargs.pop("static_scale_x", None)
         super().__init__(*args, **kwargs)
 
         # Defines the scaling behavior of x, w, dL_dY
         self.scaling_type_x = scaling_type_x
         self.scaling_type_w = scaling_type_w
+        assert self.scaling_type_w in (
+            TensorScalingType.DELAYED,
+            TensorScalingType.DYNAMIC,
+        ), "unsupported"
         self.scaling_type_dL_dY = scaling_type_dL_dY
+        assert self.scaling_type_dL_dY in (
+            TensorScalingType.DELAYED,
+            TensorScalingType.DYNAMIC,
+        ), "unsupported"
         # Convenience flag to skip code related to delayed scaling
         self.has_any_delayed_scaling = (
            self.scaling_type_x is TensorScalingType.DELAYED
            or self.scaling_type_w is TensorScalingType.DELAYED
            or self.scaling_type_dL_dY is TensorScalingType.DELAYED
         )
 
+        if static_scale_x is not None:
+            assert self.scaling_type_x is TensorScalingType.STATIC, "unsupported"
+            self.register_always_float32_buffer("static_scale_x", static_scale_x)
+
         # TODO(future): have a unique recipe per buffer instead of one per
         # module, saving implementing that until we need it.
         # TODO(future): serialization for recipes
@@ -306,9 +323,16 @@ def cast_x_to_float8(
                 self.fp8_amax_x,
                 self.forward_config,
             )
-        else:
-            assert self.scaling_type_x is TensorScalingType.DYNAMIC
+        elif self.scaling_type_x is TensorScalingType.DYNAMIC:
             x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+        else:
+            assert self.scaling_type_x is TensorScalingType.STATIC
+            x_fp8 = Float8Tensor.to_float8(
+                x,
+                self.static_scale_x,
+                e4m3_dtype,
+                mm_config=self.forward_config,
+            )
         return x_fp8
 
     def cast_w_to_float8(
@@ -417,6 +441,7 @@ def from_float(
         scaling_type_x=TensorScalingType.DELAYED,
         scaling_type_w=TensorScalingType.DELAYED,
         scaling_type_dL_dY=TensorScalingType.DELAYED,
+        static_scale_x=None,
     ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -434,6 +459,7 @@ def from_float(
             scaling_type_w=scaling_type_w,
             scaling_type_dL_dY=scaling_type_dL_dY,
             emulate=emulate,
+            static_scale_x=static_scale_x,
         )
         if (
             scaling_type_w == TensorScalingType.DYNAMIC
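For a single layer, the new `from_float` argument added above could be exercised roughly as follows; a minimal sketch, assuming the repo's import paths and that `emulate` keeps its default:

```python
import torch
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

# a plain linear whose input is assumed to come from a sigmoid (amax <= 1.0)
m = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)

m_fp8 = Float8Linear.from_float(
    m,
    scaling_type_x=TensorScalingType.STATIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    # e.g. e4m3 max (448.0) divided by the sigmoid output bound of 1.0
    static_scale_x=torch.tensor(448.0, device="cuda"),
)
```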

float8_experimental/float8_linear_utils.py

Lines changed: 3 additions & 0 deletions

@@ -154,6 +154,7 @@ def swap_linear_with_float8_linear(
     scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
+    static_scale_x: Optional[float] = None,
 ) -> Optional[nn.Module]:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8Linear`.
@@ -167,6 +168,7 @@ def swap_linear_with_float8_linear(
         scaling_type_x (TensorScalingType): scaling type for `x`
         scaling_type_w (TensorScalingType): scaling type for `w`
         scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
+        static_scale_x: static scale for `x`
 
     Returns:
         nn.Module: The modified module with swapped linear layers.
@@ -177,6 +179,7 @@ def swap_linear_with_float8_linear(
         scaling_type_x=scaling_type_x,
         scaling_type_w=scaling_type_w,
         scaling_type_dL_dY=scaling_type_dL_dY,
+        static_scale_x=static_scale_x,
     )
     return swap_linear_layers(
         module,
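At the module level, the wiring mirrors what the benchmark does; a minimal sketch, assuming the repo's import paths and using the benchmark's dummy scale of 1.0:

```python
import copy
import torch
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# stand-in for the SigmoidLinear model added in the benchmark diff above
m_ref = torch.nn.Sequential(
    torch.nn.Sigmoid(),
    torch.nn.Linear(4096, 4096),
).to("cuda", torch.bfloat16)

# swap the linear in a copy of the model, using static scaling for `x`
m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(
    m_float8,
    scaling_type_x=TensorScalingType.STATIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    static_scale_x=torch.tensor(1.0, device="cuda"),
)
```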
