This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 61d6611

[POC] Added fp8 all-gather extensions

Andrew Gu authored and committed

ghstack-source-id: ff0433d
Pull Request resolved: #216

1 parent de10c25, commit 61d6611

9 files changed: +1117 -131 lines

float8_experimental/float8_dynamic_linear.py (52 additions & 5 deletions)
@@ -6,7 +6,11 @@
 """
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """
+
+from typing import Any, cast, Optional, Tuple, Union
+
 import torch
+import torch.nn as nn

 from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
 from float8_experimental.float8_utils import tensor_to_scale
@@ -73,7 +77,11 @@ def forward(self, x):
         x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)

         # cast w to float8_e4m3fn
-        w_fp8 = self.cast_to_float8_e4m3fn(self.weight)
+        w_fp8 = (
+            self.weight
+            if isinstance(self.weight, Float8Tensor)
+            else self.cast_to_float8_e4m3fn(self.weight)
+        )

         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

@@ -83,8 +91,10 @@ def forward(self, x):

         return y

-    def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
-        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
+    def cast_to_float8_e4m3fn(
+        self, inpt_tensor: torch.Tensor, reduce_amax: bool = False
+    ) -> Float8Tensor:
+        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
         )
@@ -94,7 +104,11 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:

     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, use_activation_hooks: bool = False
+        cls,
+        mod: nn.Module,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        use_fp8_all_gather: bool = False,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -111,7 +125,11 @@ def from_float(
             "bias": False,
         }
         new_mod = cls(use_activation_hooks, **super_kwargs)
-        new_mod.weight = mod.weight
+        new_mod.weight = (
+            nn.Parameter(Float8DynamicLinearWeightTensor(mod.weight))
+            if use_fp8_all_gather
+            else mod.weight
+        )
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         if new_mod.use_activation_hooks:
@@ -121,3 +139,32 @@ def from_float(
                 cast_grad_to_float8_e5m2_backward_forward_hook
             )
         return new_mod
+
+
+class Float8DynamicLinearWeightTensor(torch.Tensor):
+    # TODO: Remove `module` arg, save state on subclass, and propagate it.
+    def fsdp_pre_all_gather(
+        self, module: nn.Module
+    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = module.cast_to_float8_e4m3fn(self, reduce_amax=True)
+        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
+        (data,) = all_gather_outputs
+        scale, emulate = metadata
+        if out is not None:
+            out = cast(Float8Tensor, out)
+            assert (
+                data.untyped_storage().data_ptr()
+                == out._data.untyped_storage().data_ptr()
+            )
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
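The two methods added above define the whole extension contract: fsdp_pre_all_gather returns the tensors to communicate plus opaque metadata, and fsdp_post_all_gather rebuilds a Float8Tensor from the gathered bits (or, when a preallocated `out` is supplied, just refreshes its scale). Below is a minimal, self-contained sketch of that contract on a toy tensor subclass, with a fake all-gather standing in for the collective. The names (ToyFp8Weight, fake_all_gather), the simplified scaling convention, and the driver loop are illustrative assumptions; they are not the library's Float8Tensor math nor FSDP's actual call sequence, and the preallocated-`out` path is omitted for brevity.

from typing import Any, Optional, Tuple

import torch


class ToyFp8Weight(torch.Tensor):
    # Illustrative subclass mirroring the hook signatures used in the diff above.
    def fsdp_pre_all_gather(self, module) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        # Cast once before communication: ship raw fp8 bits, carry the scale as metadata.
        scale = self.detach().abs().max().clamp(min=1e-12)
        data = (self.detach() / scale).to(torch.float8_e4m3fn)
        return (data,), (scale,)

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        (scale,) = metadata
        # Rebuild a high-precision tensor for compute from the gathered fp8 payload.
        return data.to(param_dtype) * scale, (data,)


def fake_all_gather(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the collective; a real implementation concatenates shards from all ranks.
    return t.clone()


weight = ToyFp8Weight(torch.randn(4, 4))  # mirrors Float8DynamicLinearWeightTensor(mod.weight)
tensors, meta = weight.fsdp_pre_all_gather(module=None)
gathered = tuple(fake_all_gather(t) for t in tensors)
full_weight, _ = weight.fsdp_post_all_gather(gathered, meta, torch.float32)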

float8_experimental/float8_linear.py (48 additions & 5 deletions)
@@ -14,14 +14,14 @@

 import dataclasses

-from typing import Optional
+from typing import Any, cast, Optional, Tuple, Union

 import float8_experimental.config as config

 import torch
+import torch.nn as nn

 from float8_experimental.float8_tensor import Float8Tensor
-
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
     E4M3_MAX_POS,
@@ -293,7 +293,11 @@ def forward(self, x):
         self.float8_pre_forward(x)

         x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        w_fp8 = (
+            self.weight
+            if isinstance(self.weight, Float8Tensor)
+            else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        )

         y = torch.matmul(x_fp8, w_fp8.t())

@@ -307,7 +311,13 @@ def forward(self, x):
         return y

     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
+    def from_float(
+        cls,
+        mod: nn.Module,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        use_fp8_all_gather: bool = False,
+    ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear

@@ -321,9 +331,42 @@ def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = Fal
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
-        new_mod.weight = mod.weight
+        new_mod.weight = (
+            nn.Parameter(Float8LinearWeightTensor(mod.weight))
+            if use_fp8_all_gather
+            else mod.weight
+        )
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         return new_mod
+
+
+class Float8LinearWeightTensor(torch.Tensor):
+    # TODO: Remove `module` arg, save state on subclass, and propagate it.
+    def fsdp_pre_all_gather(
+        self, module: nn.Module
+    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
+        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
+        (data,) = all_gather_outputs
+        scale, emulate = metadata
+        if out is not None:
+            out = cast(Float8Tensor, out)
+            assert (
+                data.untyped_storage().data_ptr()
+                == out._data.untyped_storage().data_ptr()
+            )
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
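One detail worth calling out in the `out is not None` branch of both new classes: when a preallocated Float8Tensor is handed back, the gathered `data` is expected to alias the storage of `out._data`, so the hook only refreshes `out._scale` instead of building a new tensor. A tiny standalone illustration of that aliasing check, using only stock PyTorch calls, follows; the buffer names are illustrative and not part of the library.

import torch

# `buf` stands in for the flat all-gather output buffer; both views share its storage,
# which is exactly what the assert in fsdp_post_all_gather verifies via data_ptr().
buf = torch.empty(16, dtype=torch.uint8)
gathered_data = buf[:16]                 # plays the role of `data` from all_gather_outputs
preallocated_out_data = buf.view(4, 4)   # plays the role of `out._data`
assert (
    gathered_data.untyped_storage().data_ptr()
    == preallocated_out_data.untyped_storage().data_ptr()
)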

float8_experimental/float8_linear_utils.py (10 additions & 2 deletions)
@@ -74,6 +74,7 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     use_activation_hooks: bool = False,
+    use_fp8_all_gather: bool = False,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -86,6 +87,7 @@
             Linear submodules of these skipped modules will also be skipped.
         emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
         use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
+        use_fp8_all_gather (bool): Whether to cast to fp8 before all-gather when using FSDP.
     """
     module_names_to_skip = set(skip_fqn_list or [])
     if isinstance(module, nn.Linear):
@@ -94,7 +96,10 @@
                 f"Does not support a root nn.Linear with children: {module}"
            )
        return module_cls.from_float(
-            module, emulate=emulate, use_activation_hooks=use_activation_hooks
+            module,
+            emulate=emulate,
+            use_activation_hooks=use_activation_hooks,
+            use_fp8_all_gather=use_fp8_all_gather,
        )

    # Mark all modules to skip as visited
@@ -118,7 +123,10 @@ def post_order_traversal(
                parent_module is not None
            ), f"Linear root module should return early: {module}"
            float8linear_module = module_cls.from_float(
-                module, emulate=emulate, use_activation_hooks=use_activation_hooks
+                module,
+                emulate=emulate,
+                use_activation_hooks=use_activation_hooks,
+                use_fp8_all_gather=use_fp8_all_gather,
            )
            setattr(parent_module, module_name, float8linear_module)

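Putting the pieces together, the new flag is threaded from the swap utility into `from_float`, which wraps the weight in the all-gather-aware subclass. A hedged usage sketch follows: the two positional arguments (the root module and the Float8 linear class) are inferred from the `module_cls.from_float(...)` calls above rather than shown in this diff, and the commented `fully_shard` lines assume the per-parameter FSDP prototype as the intended consumer of these hooks.

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
model = swap_linear_with_float8_linear(
    model,
    Float8DynamicLinear,      # module_cls; positional placement is an assumption
    use_fp8_all_gather=True,  # wrap weights so FSDP all-gathers fp8 bits plus a scale
)

# Sharding step (assumed API, not part of this commit):
# from torch.distributed._composable.fsdp import fully_shard
# for layer in model:
#     fully_shard(layer)
# fully_shard(model)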

float8_experimental/float8_utils.py (2 additions & 2 deletions)
@@ -65,8 +65,8 @@ def tensor_to_amax(x, distributed_reduction=False):


 @torch.no_grad()
-def tensor_to_scale(x, float8_dtype):
-    amax = tensor_to_amax(x)
+def tensor_to_scale(x, float8_dtype: torch.dtype, distributed_reduction: bool = False):
+    amax = tensor_to_amax(x, distributed_reduction=distributed_reduction)
     return amax_to_scale(amax, float8_dtype, x.dtype)

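The `distributed_reduction` plumbing (surfaced as `reduce_amax` in the dynamic linear above) matters because each rank only sees its own weight shard: the shard-local amax values must be reduced (MAX across ranks) before a scale is derived, otherwise ranks would quantize with mismatched scales and the gathered fp8 data would be inconsistent. A hedged sketch of that reduction, using only standard `torch.distributed` calls, is shown below; the exact reduction inside `tensor_to_amax` is not part of this diff.

import torch
import torch.distributed as dist


def sharded_amax(shard: torch.Tensor) -> torch.Tensor:
    # Shard-local amax, then agree on a single global amax across ranks.
    amax = shard.abs().max().float()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    return amax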

test/test_everything.sh (1 addition & 1 deletion)
@@ -9,6 +9,6 @@ pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_tp.sh
-pytest test/test_fsdp/test_flat_param_fsdp_compile.py
+pytest test/test_fsdp/*

 echo "all tests successful"
