
Commit b7584ba

Author: Andrew Gu

Removed module arg from fsdp_pre_all_gather

ghstack-source-id: aa0fe4a
Pull Request resolved: #217

1 parent 61d6611, commit b7584ba
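
Why the `module` argument can be dropped (a hedged aside, not part of the commit): the state the FSDP hooks need now lives on the weight subclass itself, namely a bound cast method plus the `emulate` flag, and a bound method already keeps a reference to its module. The toy below only illustrates that bound-method point; ToyLinear and cast_w are hypothetical names, not float8_experimental code.

import torch
import torch.nn as nn


class ToyLinear(nn.Linear):
    def cast_w(self, w: torch.Tensor) -> torch.Tensor:
        # Reads per-module state through `self` (as the updated
        # cast_w_to_float8 does with self.is_amax_initialized), so callers
        # no longer pass the module or its flags explicitly.
        return w * 1.0


mod = ToyLinear(4, 4, bias=False)
cast_fn = mod.cast_w          # bound method: closes over `mod`
out = cast_fn(mod.weight)     # no explicit module argument at the call site
assert cast_fn.__self__ is mod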

File tree: 2 files changed, 80 additions (+) and 29 deletions (-)

float8_experimental/float8_dynamic_linear.py
41 additions, 14 deletions

@@ -7,14 +7,17 @@
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """
 
-from typing import Any, cast, Optional, Tuple, Union
+import functools
+from typing import Any, Callable, cast, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
 from float8_experimental.float8_utils import tensor_to_scale
 
+from torch.utils._pytree import tree_map
+
 
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -125,11 +128,13 @@ def from_float(
             "bias": False,
         }
         new_mod = cls(use_activation_hooks, **super_kwargs)
-        new_mod.weight = (
-            nn.Parameter(Float8DynamicLinearWeightTensor(mod.weight))
-            if use_fp8_all_gather
-            else mod.weight
-        )
+        if use_fp8_all_gather:
+            cast_fn = new_mod.cast_to_float8_e4m3fn
+            new_mod.weight = nn.Parameter(
+                Float8DynamicLinearWeightTensor(mod.weight, cast_fn, emulate)
+            )
+        else:
+            new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         if new_mod.use_activation_hooks:
@@ -142,12 +147,34 @@ def from_float(
 
 
 class Float8DynamicLinearWeightTensor(torch.Tensor):
-    # TODO: Remove `module` arg, save state on subclass, and propagate it.
-    def fsdp_pre_all_gather(
-        self, module: nn.Module
-    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
-        float8_tensor = module.cast_to_float8_e4m3fn(self, reduce_amax=True)
-        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+    def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        return cls._make_subclass(cls, tensor, tensor.requires_grad)
+
+    def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        super().__init__()
+        self.cast_fn = cast_fn
+        self.emulate = emulate
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def wrap(cast_fn: Callable, emulate: bool, o: Any):
+            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
+                return cls(o, cast_fn, emulate)
+            return o
+
+        with torch._C.DisableTorchFunctionSubclass():
+            if isinstance(args[0], cls):
+                out = func(*args, **kwargs)
+                return tree_map(
+                    functools.partial(wrap, args[0].cast_fn, args[0].emulate), out
+                )
+            return func(*args, **kwargs)
+
+    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = self.cast_fn(self, reduce_amax=True)
+        return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(
         self,
@@ -158,7 +185,7 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
         (data,) = all_gather_outputs
-        scale, emulate = metadata
+        (scale,) = metadata
         if out is not None:
             out = cast(Float8Tensor, out)
             assert (
@@ -167,4 +194,4 @@ def fsdp_post_all_gather(
             )
             out._scale = scale
             return
-        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
+        return Float8Tensor(data, scale, param_dtype, self.emulate), (data,)
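
A minimal, hedged check of the state-propagation behavior this file adds (it assumes float8_experimental is importable and a PyTorch build that provides torch._C.DisableTorchFunctionSubclass). The no-op fake_cast is a placeholder, not a real float8 cast.

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinearWeightTensor


def fake_cast(t, reduce_amax=False):
    # Placeholder for Float8DynamicLinear.cast_to_float8_e4m3fn; illustration only.
    return t


w = Float8DynamicLinearWeightTensor(torch.randn(4, 4), fake_cast, emulate=True)

# Because __torch_function__ re-wraps outputs, ops on the weight keep the
# subclass and its stored state instead of decaying to a plain torch.Tensor.
v = w.detach().t()
assert isinstance(v, Float8DynamicLinearWeightTensor)
assert v.emulate and v.cast_fn is fake_cast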

float8_experimental/float8_linear.py
39 additions, 15 deletions

@@ -13,8 +13,9 @@
 """
 
 import dataclasses
+import functools
 
-from typing import Any, cast, Optional, Tuple, Union
+from typing import Any, Callable, cast, Optional, Tuple, Union
 
 import float8_experimental.config as config
 
@@ -29,6 +30,7 @@
     tensor_to_amax,
     to_fp8_saturated,
 )
+from torch.utils._pytree import tree_map
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -221,9 +223,7 @@ def cast_x_to_float8(
         )
         return x_fp8
 
-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+    def cast_w_to_float8(self, w: torch.Tensor) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
             w,
@@ -232,7 +232,7 @@ def cast_w_to_float8(
             self.fp8_scale_w,
             scale_fn_name,
             torch.float8_e4m3fn,
-            is_amax_initialized,
+            self.is_amax_initialized,
         )
 
         w_fp8 = Float8Tensor.to_float8(
@@ -296,7 +296,7 @@ def forward(self, x):
         w_fp8 = (
            self.weight
            if isinstance(self.weight, Float8Tensor)
-            else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+            else self.cast_w_to_float8(self.weight)
         )
 
         y = torch.matmul(x_fp8, w_fp8.t())
@@ -332,7 +332,9 @@ def from_float(
         # with torch.device("meta"):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = (
-            nn.Parameter(Float8LinearWeightTensor(mod.weight))
+            nn.Parameter(
+                Float8LinearWeightTensor(mod.weight, new_mod.cast_w_to_float8, emulate)
+            )
             if use_fp8_all_gather
             else mod.weight
         )
@@ -344,12 +346,34 @@ def from_float(
 
 
 class Float8LinearWeightTensor(torch.Tensor):
-    # TODO: Remove `module` arg, save state on subclass, and propagate it.
-    def fsdp_pre_all_gather(
-        self, module: nn.Module
-    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
-        float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
-        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+    def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        return cls._make_subclass(cls, tensor, tensor.requires_grad)
+
+    def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        super().__init__()
+        self.cast_fn = cast_fn
+        self.emulate = emulate
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def wrap(cast_fn: Callable, emulate: bool, o: Any):
+            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
+                return cls(o, cast_fn, emulate)
+            return o
+
+        with torch._C.DisableTorchFunctionSubclass():
+            if isinstance(args[0], cls):
+                out = func(*args, **kwargs)
+                return tree_map(
+                    functools.partial(wrap, args[0].cast_fn, args[0].emulate), out
+                )
+            return func(*args, **kwargs)
+
+    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = self.cast_fn(self)
+        return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(
         self,
@@ -360,7 +384,7 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
         (data,) = all_gather_outputs
-        scale, emulate = metadata
+        (scale,) = metadata
         if out is not None:
             out = cast(Float8Tensor, out)
             assert (
@@ -369,4 +393,4 @@ def fsdp_post_all_gather(
             )
             out._scale = scale
             return
-        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
+        return Float8Tensor(data, scale, param_dtype, self.emulate), (data,)
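
A hedged sketch of the new hook contract in float8_linear.py (assumes float8_experimental is importable). FakeFloat8 is an illustrative stand-in for the Float8Tensor a real cast_fn would return; only the attributes the hook actually reads are provided.

import torch
from float8_experimental.float8_linear import Float8LinearWeightTensor


class FakeFloat8:
    # Stand-in exposing the only attributes fsdp_pre_all_gather touches.
    def __init__(self, t: torch.Tensor):
        self._data = t
        self._scale = torch.tensor(1.0)


w = Float8LinearWeightTensor(torch.randn(8, 8), FakeFloat8, emulate=True)

# The hook now takes no `module` argument: it calls the stored cast_fn, and the
# metadata tuple carries only the scale; `emulate` is read from the subclass
# later, in fsdp_post_all_gather.
all_gather_inputs, metadata = w.fsdp_pre_all_gather()
(data,), (scale,) = all_gather_inputs, metadata
assert data.shape == (8, 8) and scale.item() == 1.0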
