This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 1194661

[1/x]: Make Float8Linear support dynamic scaling
Summary:

At a high level, we need to make dynamic vs delayed scaling configurable separately for activations, weights and gradients. The way I am approaching this is as follows:

* PR 1 (this PR): add basic support for dynamic scaling, configurable by tensor, to `Float8Linear`
* PRs 2..n: one by one, add features implemented in `Float8DynamicLinear` to `Float8Linear`, as necessary
* last PR: delete `Float8DynamicLinear`

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: bb1bfe5
Pull Request resolved: #290
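For a concrete picture of the intended usage, here is a minimal sketch based on the `from_float` signature added in this PR; the layer shape and the particular delayed/dynamic mix are illustrative:

```python
import torch

from float8_experimental.float8_linear import Float8Linear, TensorScalingType

# start from a regular linear layer
m = torch.nn.Linear(32, 16)

# convert to Float8Linear, choosing delayed vs dynamic scaling per tensor
m_fp8 = Float8Linear.from_float(
    m,
    scaling_type_x=TensorScalingType.DYNAMIC,      # activations: dynamic
    scaling_type_w=TensorScalingType.DELAYED,      # weights: delayed (default)
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,  # gradients: dynamic
)

# the per-tensor choice is visible in the repr, e.g.
# Float8Linear(in_features=32, out_features=16, bias=True, scaling="x:dyn,w:del,dldy:dyn")
print(m_fp8)
```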
1 parent 0b60496 commit 1194661

6 files changed: +438 −132 lines

float8_experimental/float8_dynamic_linear.py

Lines changed: 6 additions & 6 deletions
@@ -63,13 +63,13 @@ def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
     def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
+            w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_bw(y, self.backward_config)
+        y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
         return y
 
     @classmethod
@@ -111,7 +111,7 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         return new_mod
 
 
-def cast_to_float8_e4m3fn(
+def cast_to_float8_e4m3_dynamic(
     inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
@@ -120,7 +120,7 @@ def cast_to_float8_e4m3fn(
     return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
 
 
-def cast_to_float8_e5m2_bw(
+def cast_to_float8_e5m2_dynamic_bw(
     gradY: torch.Tensor, mm_config: ScaledMMConfig
 ) -> torch.Tensor:
     return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
@@ -199,7 +199,7 @@ def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
 
     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3fn(
+        float8_tensor = cast_to_float8_e4m3_dynamic(
             self._tensor, self._mm_config, reduce_amax=True
         )
         return (float8_tensor._data,), (float8_tensor._scale,)
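The renamed helpers keep their existing signatures. A rough sketch of how they compose outside the module is below; the `ScaledMMConfig` values mirror the emulated config built in `from_float`, and the argument meanings plus the CUDA device are assumptions of the sketch:

```python
import torch

from float8_experimental.float8_dynamic_linear import (
    cast_to_float8_e4m3_dynamic,
    cast_to_float8_e5m2_dynamic_bw,
)
from float8_experimental.float8_tensor import ScaledMMConfig

# assumption: emulated config, mirroring the ScaledMMConfig(emulate, ...) call in from_float
mm_config = ScaledMMConfig(True, False, False, False)

x = torch.randn(16, 32, device="cuda")
w = torch.randn(64, 32, device="cuda")

x_fp8 = cast_to_float8_e4m3_dynamic(x, mm_config)  # scale computed from the current tensor
w_fp8 = cast_to_float8_e4m3_dynamic(w, mm_config)
y = torch.nn.functional.linear(x_fp8, w_fp8)
y = cast_to_float8_e5m2_dynamic_bw(y, mm_config)   # no-op in fwd, casts dL/dY to e5m2 in bwd
```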

float8_experimental/float8_linear.py

Lines changed: 167 additions & 77 deletions
@@ -8,13 +8,19 @@
 """
 
 import dataclasses
+import enum
 
 from typing import Optional
 
 import float8_experimental.config as config
 
 import torch
 
+from float8_experimental.float8_dynamic_linear import (
+    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_e5m2_dynamic_bw,
+)
+
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     ScaledMMConfig,
@@ -125,20 +131,54 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
+class TensorScalingType(enum.Enum):
+    DELAYED = "delayed"
+    DYNAMIC = "dynamic"
+
+    def short_str(self):
+        if self is TensorScalingType.DELAYED:
+            return "del"
+        else:
+            assert self is TensorScalingType.DYNAMIC
+            return "dyn"
+
+
 class Float8Linear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
     scales in way friendly to delayed scaling.
     """
 
     def __init__(self, *args, **kwargs):
+        """
+        Additional arguments on top of `torch.nn.Linear`'s arguments:
+        * `delayed_scaling_recipe`: configuration for delayed scaling
+        * `scaling_type_x`: delayed vs dynamic scaling for `x`
+        * `scaling_type_w`: delayed vs dynamic scaling for `w`
+        * `scaling_type_dL_dY`: delayed vs dynamic scaling for `dL_dY`
+        """
+
         delayed_scaling_recipe = kwargs.pop(
             "delayed_scaling_recipe", DelayedScalingRecipe()
         )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
+        scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
+        scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
+        scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
         super().__init__(*args, **kwargs)
 
+        # Defines the scaling behavior of x, w, dL_dY
+        self.scaling_type_x = scaling_type_x
+        self.scaling_type_w = scaling_type_w
+        self.scaling_type_dL_dY = scaling_type_dL_dY
+        # Convenience flag to skip code related to delayed scaling
+        self.has_any_delayed_scaling = (
+            self.scaling_type_x is TensorScalingType.DELAYED
+            or self.scaling_type_w is TensorScalingType.DELAYED
+            or self.scaling_type_dL_dY is TensorScalingType.DELAYED
+        )
+
         # TODO(future): have a unique recipe per buffer instead of one per
         # module, saving implementing that until we need it.
         # TODO(future): serialization for recipes
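A tiny sketch of what the new constructor arguments control; direct construction on the meta device follows what `from_float` does below, the values are illustrative, and the extra setup performed by `from_float` is skipped:

```python
import torch

from float8_experimental.float8_linear import Float8Linear, TensorScalingType

with torch.device("meta"):
    m = Float8Linear(
        32,
        16,
        bias=False,
        scaling_type_x=TensorScalingType.DYNAMIC,
        scaling_type_w=TensorScalingType.DYNAMIC,
        scaling_type_dL_dY=TensorScalingType.DELAYED,
    )

# dL_dY still uses delayed scaling, so the delayed-scaling code paths stay enabled
assert m.has_any_delayed_scaling
assert m.scaling_type_x.short_str() == "dyn"
```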
@@ -175,37 +215,44 @@ def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.recipe.history_len
         device = self.weight.device
+        # TODO(future PR): dtype values below don't have the other float8
+        # flavors, fix it
         default_x = torch.finfo(torch.float8_e4m3fn).max
         default_w = torch.finfo(torch.float8_e4m3fn).max
         default_dl_dy = torch.finfo(torch.float8_e5m2).max
 
-        self.register_always_float32_buffer(
-            "fp8_amax_x", torch.tensor([default_x], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_history_x", torch.zeros(history_len, device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_scale_x", torch.tensor([1.0], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_w", torch.tensor([default_w], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_history_w", torch.zeros(history_len, device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_scale_w", torch.tensor([1.0], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_dL_dY", torch.tensor([default_dl_dy], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_history_dL_dY", torch.zeros(history_len, device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_scale_dL_dY", torch.tensor([1.0], device=device)
-        )
+        # Note: for now, create all the buffers if any are needed, to postpone
+        # the work to make the scale and amax syncing and history calculation
+        # handle a heterogeneous setup. We can do that work later if benchmarks
+        # show it is worth doing.
+        if self.has_any_delayed_scaling:
+            self.register_always_float32_buffer(
+                "fp8_amax_x", torch.tensor([default_x], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_history_x", torch.zeros(history_len, device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_scale_x", torch.tensor([1.0], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_w", torch.tensor([default_w], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_history_w", torch.zeros(history_len, device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_scale_w", torch.tensor([1.0], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_dL_dY", torch.tensor([default_dl_dy], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_history_dL_dY", torch.zeros(history_len, device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_scale_dL_dY", torch.tensor([1.0], device=device)
+            )
 
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
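One observable consequence of gating buffer creation on `has_any_delayed_scaling`, written as a quick check; the layer shape is illustrative and `from_float` is used as elsewhere in this PR:

```python
import torch

from float8_experimental.float8_linear import Float8Linear, TensorScalingType

m_dyn = Float8Linear.from_float(
    torch.nn.Linear(32, 16),
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
# all-dynamic: none of the fp8_amax_* / fp8_scale_* buffers are registered
assert not any(n.startswith("fp8_") for n, _ in m_dyn.named_buffers())

m_del = Float8Linear.from_float(torch.nn.Linear(32, 16))  # defaults: all delayed
# delayed scaling keeps the amax / amax-history / scale buffers
assert any(n == "fp8_amax_x" for n, _ in m_del.named_buffers())
```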
@@ -234,61 +281,77 @@ def cast_x_to_float8(
             autocast_dtype = torch.get_autocast_gpu_dtype()
             x = x.to(autocast_dtype)
 
-        scale_fn_name = self.recipe.scale_fn_name
-        _maybe_initialize_amaxes_scales_for_float8_cast(
-            x,
-            self.fp8_amax_x,
-            self.fp8_amax_history_x,
-            self.fp8_scale_x,
-            scale_fn_name,
-            e4m3_dtype,
-            is_amax_initialized,
-            reduce_amax=True,
-        )
-        x_fp8 = Float8Tensor.to_float8(
-            x,
-            self.fp8_scale_x,
-            e4m3_dtype,
-            self.fp8_amax_x,
-            self.forward_config,
-        )
+        if self.scaling_type_x is TensorScalingType.DELAYED:
+            scale_fn_name = self.recipe.scale_fn_name
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                x,
+                self.fp8_amax_x,
+                self.fp8_amax_history_x,
+                self.fp8_scale_x,
+                scale_fn_name,
+                e4m3_dtype,
+                is_amax_initialized,
+                reduce_amax=True,
+            )
+            x_fp8 = Float8Tensor.to_float8(
+                x,
+                self.fp8_scale_x,
+                e4m3_dtype,
+                self.fp8_amax_x,
+                self.forward_config,
+            )
+        else:
+            assert self.scaling_type_x is TensorScalingType.DYNAMIC
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
         return x_fp8
 
     def cast_w_to_float8(
         self, w: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
-        scale_fn_name = self.recipe.scale_fn_name
-        _maybe_initialize_amaxes_scales_for_float8_cast(
-            w,
-            self.fp8_amax_w,
-            self.fp8_amax_history_w,
-            self.fp8_scale_w,
-            scale_fn_name,
-            e4m3_dtype,
-            is_amax_initialized,
-            reduce_amax=False,
-        )
+        if self.scaling_type_w is TensorScalingType.DELAYED:
+            scale_fn_name = self.recipe.scale_fn_name
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                w,
+                self.fp8_amax_w,
+                self.fp8_amax_history_w,
+                self.fp8_scale_w,
+                scale_fn_name,
+                e4m3_dtype,
+                is_amax_initialized,
+                reduce_amax=False,
+            )
 
-        w_fp8 = Float8Tensor.to_float8(
-            w,
-            self.fp8_scale_w,
-            e4m3_dtype,
-            self.fp8_amax_w,
-            self.forward_config,
-        )
+            w_fp8 = Float8Tensor.to_float8(
+                w,
+                self.fp8_scale_w,
+                e4m3_dtype,
+                self.fp8_amax_w,
+                self.forward_config,
+            )
+        else:
+            assert self.scaling_type_w is TensorScalingType.DYNAMIC
+            # TODO(future): also support FSDP integration in delayed scaling path
+            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
+                w_fp8 = self.weight
+            else:
+                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
         return w_fp8
 
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
-        scale_fn_name = self.recipe.scale_fn_name
-        y = NoopFwToFloat8E5M2Bw.apply(
-            y,
-            self.fp8_amax_dL_dY,
-            self.fp8_amax_history_dL_dY,
-            self.fp8_scale_dL_dY,
-            scale_fn_name,
-            self.is_amax_initialized,
-            self.backward_config,
-        )
+        if self.scaling_type_dL_dY is TensorScalingType.DELAYED:
+            scale_fn_name = self.recipe.scale_fn_name
+            y = NoopFwToFloat8E5M2Bw.apply(
+                y,
+                self.fp8_amax_dL_dY,
+                self.fp8_amax_history_dL_dY,
+                self.fp8_scale_dL_dY,
+                scale_fn_name,
+                self.is_amax_initialized,
+                self.backward_config,
+            )
+        else:
+            assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
         return y
 
     def float8_pre_forward(self, x):
@@ -313,7 +376,8 @@ def float8_post_forward(self):
         self.amax_and_scale_synced = False
 
     def forward(self, x):
-        self.float8_pre_forward(x)
+        if self.has_any_delayed_scaling:
+            self.float8_pre_forward(x)
 
         x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
@@ -326,11 +390,29 @@ def forward(self, x):
         if self.bias is not None:
             y = y + self.bias.to(y.dtype)
 
-        self.float8_post_forward()
+        if self.has_any_delayed_scaling:
+            self.float8_post_forward()
         return y
 
+    def extra_repr(self):
+        # example: in_features=32, out_features=16, bias=True
+        s = super().extra_repr()
+        # add scaling settings without using too many characters
+        scaling = f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
+
+        s = f'{s}, scaling="{scaling}"'
+        # example: in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
+        return s
+
     @classmethod
-    def from_float(cls, mod, emulate: bool = False):
+    def from_float(
+        cls,
+        mod,
+        emulate: bool = False,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
@@ -339,14 +421,22 @@ def from_float(cls, mod, emulate: bool = False):
             emulate (bool): whether to emulate fp8 matmul logic in float32
         """
         with torch.device("meta"):
-            new_mod = cls(mod.in_features, mod.out_features, bias=False)
+            new_mod = cls(
+                mod.in_features,
+                mod.out_features,
+                bias=False,
+                scaling_type_x=scaling_type_x,
+                scaling_type_w=scaling_type_w,
+                scaling_type_dL_dY=scaling_type_dL_dY,
+            )
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         # need to create buffers again when moving from meta device to
         # real device
         new_mod.create_buffers()
         # Defines the behavior of the matmul in the forward and backward
         # Forward we use fast_accum, backwards we do not
+        # TODO(future PR): move below to the constructor
        new_mod.forward_config = ScaledMMConfig(
             emulate, True if not emulate else False, False, config.pad_inner_dim
         )