This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 6e9a2ae

Author: Andrew Gu (committed)
[POC] Added fp8 all-gather extensions
ghstack-source-id: d1ab08f
Pull Request resolved: #216
Parent: 1d45441

9 files changed: 1120 additions, 145 deletions
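For context (not part of the commit): all-gathering fp8 data cuts FSDP's per-parameter communication volume, since float8_e4m3fn uses 1 byte per element versus 2 for bf16 or 4 for fp32, at the cost of shipping one extra scale per tensor. A back-of-the-envelope sketch for a single weight; the 4096x4096 shape is an illustrative assumption, not taken from the diff:

# Illustrative arithmetic only.
numel = 4096 * 4096
bf16_bytes = numel * 2          # bytes all-gathered if the weight stays in bf16
fp8_bytes = numel * 1 + 4       # fp8 payload plus one fp32 scale per tensor
print(bf16_bytes / fp8_bytes)   # roughly 2x less all-gather traffic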

float8_experimental/float8_dynamic_linear.py

Lines changed: 52 additions & 5 deletions
@@ -6,7 +6,11 @@
 """
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """
+
+from typing import Any, cast, Optional, Tuple, Union
+
 import torch
+import torch.nn as nn

 from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
 from float8_experimental.float8_utils import tensor_to_scale
@@ -73,7 +77,11 @@ def forward(self, x):
         x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)

         # cast w to float8_e4m3fn
-        w_fp8 = self.cast_to_float8_e4m3fn(self.weight)
+        w_fp8 = (
+            self.weight
+            if isinstance(self.weight, Float8Tensor)
+            else self.cast_to_float8_e4m3fn(self.weight)
+        )

         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

@@ -83,8 +91,10 @@ def forward(self, x):

         return y

-    def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
-        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
+    def cast_to_float8_e4m3fn(
+        self, inpt_tensor: torch.Tensor, reduce_amax: bool = False
+    ) -> Float8Tensor:
+        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
         )
@@ -94,7 +104,11 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:

     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, use_activation_hooks: bool = False
+        cls,
+        mod: nn.Module,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        use_fp8_all_gather: bool = False,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -111,7 +125,11 @@ def from_float(
             "bias": False,
         }
         new_mod = cls(use_activation_hooks, **super_kwargs)
-        new_mod.weight = mod.weight
+        new_mod.weight = (
+            nn.Parameter(Float8DynamicLinearWeightTensor(mod.weight))
+            if use_fp8_all_gather
+            else mod.weight
+        )
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         if new_mod.use_activation_hooks:
@@ -121,3 +139,32 @@ def from_float(
                 cast_grad_to_float8_e5m2_backward_forward_hook
             )
         return new_mod
+
+
+class Float8DynamicLinearWeightTensor(torch.Tensor):
+    # TODO: Remove `module` arg, save state on subclass, and propagate it.
+    def fsdp_pre_all_gather(
+        self, module: nn.Module
+    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = module.cast_to_float8_e4m3fn(self, reduce_amax=True)
+        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
+        (data,) = all_gather_outputs
+        scale, emulate = metadata
+        if out is not None:
+            out = cast(Float8Tensor, out)
+            assert (
+                data.untyped_storage().data_ptr()
+                == out._data.untyped_storage().data_ptr()
+            )
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
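The new Float8DynamicLinearWeightTensor subclass exists so a per-parameter-sharding FSDP implementation can call fsdp_pre_all_gather before the collective (casting the local shard to fp8, with reduce_amax=True so every rank derives the same scale) and fsdp_post_all_gather afterwards (rebuilding a Float8Tensor that forward() consumes directly). Below is a minimal, single-process sketch of that contract using a stand-in subclass; the ToyFp8Weight name, the use of the ~448 e4m3 max, and the world-size-1 "all-gather" are illustrative assumptions, not code from this commit.

from typing import Any, Tuple

import torch


class ToyFp8Weight(torch.Tensor):
    # Stand-in mirroring the hook names FSDP looks for on the weight subclass.
    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        # Dynamic scaling: scale so that the local amax maps near the e4m3 max (~448).
        amax = self.abs().max().float()
        scale = 448.0 / torch.clamp(amax, min=1e-12)
        data = (self * scale).to(torch.float8_e4m3fn)  # 1-byte payload to all-gather
        return (data,), (scale,)

    def fsdp_post_all_gather(self, outputs, metadata, param_dtype, *, out=None):
        (data,) = outputs
        (scale,) = metadata
        # The real hook wraps data and scale into a Float8Tensor without a copy;
        # here we just dequantize to show the round trip.
        return data.to(param_dtype) / scale, (data,)


shard = ToyFp8Weight(torch.randn(4, 8))        # this rank's local weight shard
inputs, meta = shard.fsdp_pre_all_gather()     # what would be sent over the wire
gathered = inputs[0]                           # world size 1: the all-gather is a no-op
full_weight, _ = shard.fsdp_post_all_gather((gathered,), meta, torch.float32)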

float8_experimental/float8_linear.py

Lines changed: 48 additions & 5 deletions
@@ -14,14 +14,14 @@

 import dataclasses

-from typing import Optional
+from typing import Any, cast, Optional, Tuple, Union

 import float8_experimental.config as config

 import torch
+import torch.nn as nn

 from float8_experimental.float8_tensor import Float8Tensor
-
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
     E4M3_MAX_POS,
@@ -294,7 +294,11 @@ def forward(self, x):
         self.float8_pre_forward(x)

         x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        w_fp8 = (
+            self.weight
+            if isinstance(self.weight, Float8Tensor)
+            else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        )

         y = torch.matmul(x_fp8, w_fp8.t())

@@ -308,7 +312,13 @@ def forward(self, x):
         return y

     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
+    def from_float(
+        cls,
+        mod: nn.Module,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        use_fp8_all_gather: bool = False,
+    ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear

@@ -322,9 +332,42 @@ def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = Fal
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
-        new_mod.weight = mod.weight
+        new_mod.weight = (
+            nn.Parameter(Float8LinearWeightTensor(mod.weight))
+            if use_fp8_all_gather
+            else mod.weight
+        )
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         return new_mod
+
+
+class Float8LinearWeightTensor(torch.Tensor):
+    # TODO: Remove `module` arg, save state on subclass, and propagate it.
+    def fsdp_pre_all_gather(
+        self, module: nn.Module
+    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
+        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
+        (data,) = all_gather_outputs
+        scale, emulate = metadata
+        if out is not None:
+            out = cast(Float8Tensor, out)
+            assert (
+                data.untyped_storage().data_ptr()
+                == out._data.untyped_storage().data_ptr()
+            )
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
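For both weight subclasses, the out is not None branch of fsdp_post_all_gather covers the case where FSDP hands back a preallocated Float8Tensor: its _data is expected to alias the all-gather output buffer, so no copy is needed and only _scale has to be refreshed. The assert checks exactly that aliasing. A small stand-alone illustration of the check, using plain tensors rather than the library types:

import torch

all_gather_buffer = torch.empty(16, dtype=torch.uint8)    # stand-in for FSDP's output buffer
aliasing_view = all_gather_buffer.view(4, 4)               # like out._data: same storage
independent_copy = all_gather_buffer.clone().view(4, 4)    # separate storage

assert (
    aliasing_view.untyped_storage().data_ptr()
    == all_gather_buffer.untyped_storage().data_ptr()
)
assert (
    independent_copy.untyped_storage().data_ptr()
    != all_gather_buffer.untyped_storage().data_ptr()
)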

float8_experimental/float8_linear_utils.py

Lines changed: 10 additions & 2 deletions
@@ -93,6 +93,7 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     use_activation_hooks: bool = False,
+    use_fp8_all_gather: bool = False,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -105,6 +106,7 @@ def swap_linear_with_float8_linear(
            Linear submodules of these skipped modules will also be skipped.
         emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
         use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
+        use_fp8_all_gather (bool): Whether to cast to fp8 before all-gather when using FSDP.
     """
     module_names_to_skip = set(skip_fqn_list or [])
     if isinstance(module, nn.Linear):
@@ -113,7 +115,10 @@ def swap_linear_with_float8_linear(
                 f"Does not support a root nn.Linear with children: {module}"
             )
         return module_cls.from_float(
-            module, emulate=emulate, use_activation_hooks=use_activation_hooks
+            module,
+            emulate=emulate,
+            use_activation_hooks=use_activation_hooks,
+            use_fp8_all_gather=use_fp8_all_gather,
         )

     # Mark all modules to skip as visited
@@ -137,7 +142,10 @@ def post_order_traversal(
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
             float8linear_module = module_cls.from_float(
-                module, emulate=emulate, use_activation_hooks=use_activation_hooks
+                module,
+                emulate=emulate,
+                use_activation_hooks=use_activation_hooks,
+                use_fp8_all_gather=use_fp8_all_gather,
             )
             setattr(parent_module, module_name, float8linear_module)
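A minimal usage sketch for the new flag. Assumptions: swap_linear_with_float8_linear takes the target module class as its second argument (as in the library at the time), and per-parameter FSDP would be applied afterwards with the prototype fully_shard API; neither call appears in this diff.

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(256, 256, bias=False), nn.Linear(256, 256, bias=False))
swap_linear_with_float8_linear(model, Float8DynamicLinear, use_fp8_all_gather=True)

# Weights are now wrapped so FSDP can find the fsdp_pre/post_all_gather hooks.
print(type(model[0].weight))  # expected: Float8DynamicLinearWeightTensor

# With torch.distributed initialized, one would then shard the model, e.g.:
#   from torch.distributed._composable.fsdp import fully_shard  # prototype API at the time
#   fully_shard(model)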

float8_experimental/float8_utils.py

Lines changed: 2 additions & 2 deletions
@@ -86,8 +86,8 @@ def tensor_to_amax(x, distributed_reduction=False):


 @torch.no_grad()
-def tensor_to_scale(x, float8_dtype):
-    amax = tensor_to_amax(x)
+def tensor_to_scale(x, float8_dtype: torch.dtype, distributed_reduction: bool = False):
+    amax = tensor_to_amax(x, distributed_reduction=distributed_reduction)
     return amax_to_scale(amax, float8_dtype, x.dtype)

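Threading distributed_reduction through tensor_to_scale is what lets the weight hooks pass reduce_amax=True: every rank must compute its scale from a globally reduced amax, otherwise the fp8 shards being all-gathered would be encoded with different scales. A sketch of the underlying idea (the MAX all-reduce); the exact guard and helper names inside tensor_to_amax are not shown in this diff:

import torch
import torch.distributed as dist


def sketch_global_amax(x: torch.Tensor) -> torch.Tensor:
    # Local amax first, then a MAX all-reduce so every rank sees the same value.
    amax = x.abs().max()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    return amax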

test/test_everything.sh

Lines changed: 1 addition & 1 deletion
@@ -9,6 +9,6 @@ pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_tp.sh
-pytest test/test_fsdp/test_flat_param_fsdp_compile.py
+pytest test/test_fsdp/*

 echo "all tests successful"
