This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4d6ca7d

move WeightWithDynamicFloat8CastTensor to fsdp_utils.py
Summary: Refactor in preparation for adding delayed scaling support to float8 all-gather; no logic change.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: ab3f17e
Pull Request resolved: #310
1 parent: 5cd60c8

4 files changed: +145 −135 lines changed

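For code that imports the subclass directly, the practical effect of this move is a new import path. A minimal before/after sketch (illustrative only; both lines are taken from the import changes in the diff below):

# before this commit
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor

# after this commit
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor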

float8_experimental/float8_dynamic_utils.py

Lines changed: 0 additions & 130 deletions
@@ -3,24 +3,18 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-"""
-A wrapper around a `torch.nn.Linear` module which does fp8 compute.
-"""
 
 from typing import Any, Optional, Tuple
 
 import torch
-import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    merge_mm_configs,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
 from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
-from torch._prims_common import suggest_memory_format
 
 
 @torch._dynamo.allow_in_graph
@@ -63,127 +57,3 @@ def cast_to_float8_e5m2_dynamic_bw(
     gradY: torch.Tensor, mm_config: ScaledMMConfig
 ) -> torch.Tensor:
     return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
-
-
-# FSDP pads its local tensor on dim-0. The subclass should be preserved such
-# that the padded local tensor (and any transformations like copying to GPU)
-# is of the subclass as well.
-_ops_to_preserve_subclass = {
-    torch.ops.aten.empty_like.default,
-    torch.ops.aten.new_zeros.default,
-    torch.ops.aten.slice.Tensor,
-    torch.ops.aten.copy_.default,
-    torch.ops.aten.view.default,
-    torch.ops.aten.as_strided.default,
-    torch.ops.aten._to_copy.default,
-    torch.ops.aten._pin_memory.default,
-}
-
-
-class WeightWithDynamicFloat8CastTensor(torch.Tensor):
-    @staticmethod
-    def __new__(
-        cls,
-        tensor: torch.Tensor,
-        mm_config: ScaledMMConfig,
-        precomputed_scale: Optional[torch.Tensor] = None,
-    ):
-        return torch.Tensor._make_wrapper_subclass(
-            cls,
-            tensor.size(),
-            strides=tensor.stride(),
-            storage_offset=tensor.storage_offset(),
-            memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
-            layout=tensor.layout,
-            device=tensor.device,
-            pin_memory=tensor.is_pinned(),
-            requires_grad=tensor.requires_grad,
-        )
-
-    def __init__(
-        self,
-        tensor: torch.Tensor,
-        mm_config: ScaledMMConfig,
-        precomputed_scale: Optional[torch.Tensor] = None,
-    ):
-        self._tensor = tensor
-        self._mm_config = mm_config
-        # for dynamic scaling
-        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
-        # for all float8 parameters after optimizer step
-        self._precomputed_scale = precomputed_scale
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        if func == torch.ops.aten.detach.default:
-            return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
-            )
-        mm_config: Optional[ScaledMMConfig] = None
-
-        def unwrap(t):
-            nonlocal mm_config
-            if mm_config is None:
-                mm_config = t._mm_config
-            else:
-                mm_config = merge_mm_configs(mm_config, t._mm_config)
-            return t._tensor
-
-        args, kwargs = pytree.tree_map_only(
-            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
-        )
-        out = func(*args, **kwargs)
-        if func not in _ops_to_preserve_subclass:
-            return out
-        return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
-        )
-
-    def __tensor_flatten__(self):
-        if self._precomputed_scale:
-            return ["_tensor", "_precomputed_scale"], self._mm_config
-        else:
-            return ["_tensor"], self._mm_config
-
-    @staticmethod
-    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(
-            inner_tensors["_tensor"],
-            mm_config,
-            getattr(inner_tensors, "_precomputed_scale", None),
-        )
-
-    def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
-
-    def fsdp_pre_all_gather(self, mesh):
-        if self._precomputed_scale is not None:
-            float8_tensor = Float8Tensor.to_float8(
-                self._tensor,
-                self._precomputed_scale,
-                torch.float8_e4m3fn,
-                mm_config=self._mm_config,
-            )
-        else:
-            float8_tensor = cast_to_float8_e4m3_dynamic(
-                self._tensor, self._mm_config, reduce_amax=True
-            )
-        return (float8_tensor._data,), (float8_tensor._scale,)
-
-    def fsdp_post_all_gather(
-        self,
-        all_gather_outputs: Tuple[torch.Tensor, ...],
-        metadata: Any,
-        param_dtype: torch.dtype,
-        *,
-        out: Optional[torch.Tensor] = None,
-    ):
-        (data,) = all_gather_outputs
-        (scale,) = metadata
-        if out is not None:
-            assert isinstance(out, Float8Tensor), f"{type(out)}"
-            out._scale = scale
-            return
-        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)

float8_experimental/float8_linear.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,6 @@
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
-    WeightWithDynamicFloat8CastTensor,
 )
 
 from float8_experimental.float8_tensor import (
@@ -35,6 +34,8 @@
     tensor_to_amax,
 )
 
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
+
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
     x,
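float8_linear.py keeps using the subclass; only the import source changes. As context, a hypothetical sketch of the wrapping pattern such a module can apply to its weight so that FSDP2 all-gathers it in float8 (not the actual Float8Linear code; the default-constructed ScaledMMConfig is an assumption):

import torch.nn as nn
from float8_experimental.float8_tensor import ScaledMMConfig
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor

linear = nn.Linear(16, 16)
mm_config = ScaledMMConfig()  # assumption: defaults are acceptable for illustration
# Hypothetical: replace the linear weight with the FSDP-aware subclass so that
# fsdp_pre_all_gather / fsdp_post_all_gather cast and reassemble it as float8.
linear.weight = nn.Parameter(
    WeightWithDynamicFloat8CastTensor(linear.weight, mm_config)
)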

float8_experimental/fsdp_utils.py

Lines changed: 142 additions & 3 deletions
@@ -1,11 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 import math
-from typing import List
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+import torch.utils._pytree as pytree
+from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
+
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
+
 from float8_experimental.float8_utils import EPS
+from torch._prims_common import suggest_memory_format
 
 
 @torch.no_grad()
@@ -19,6 +33,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         optim.step()
         precompute_float8_dynamic_scale_for_fsdp(model)
     """
+    from float8_experimental.float8_linear import Float8Linear, TensorScalingType
    from torch.distributed._tensor import DTensor
 
     if any(
@@ -50,3 +65,127 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scales = torch.split(scale_tensor, 1)  # Replicate
     for scale, float8_linear in zip(scales, float8_linears):
         float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+
+
+# FSDP pads its local tensor on dim-0. The subclass should be preserved such
+# that the padded local tensor (and any transformations like copying to GPU)
+# is of the subclass as well.
+_ops_to_preserve_subclass = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+}
+
+
+class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
+        self._tensor = tensor
+        self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+        # for all float8 parameters after optimizer step
+        self._precomputed_scale = precomputed_scale
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDynamicFloat8CastTensor(
+                args[0]._tensor, args[0]._mm_config
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+        )
+
+    def __tensor_flatten__(self):
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        mm_config = flatten_spec
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            getattr(inner_tensors, "_precomputed_scale", None),
+        )
+
+    def __repr__(self):
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        if self._precomputed_scale is not None:
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
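The docstring excerpt in the second hunk shows where the precompute helper fits in a training loop. A minimal sketch under the same assumptions (an FSDP2-wrapped model whose float8 linears use dynamic scaling with float8 all-gather; model, optim, and dataloader are placeholders):

from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

for batch in dataloader:
    model(batch).sum().backward()
    optim.step()
    optim.zero_grad()
    # After the optimizer step, precompute new dynamic scales for the float8 linear
    # weights; each result is stored in the weight's _precomputed_scale and consumed
    # by fsdp_pre_all_gather on the next forward (see the class added above).
    precompute_float8_dynamic_scale_for_fsdp(model)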

test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 1 addition & 1 deletion
@@ -7,9 +7,9 @@
 import torch._dynamo.testing
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
 from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from test_fsdp2_common import (
     check_parity_bf16_mp,
     check_parity_no_mp,
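An illustrative check of the kind this updated import supports (a hedged sketch, not the actual test body; it assumes a model whose linear layers were swapped for float8 linears with float8 all-gather enabled):

import torch.nn as nn
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor

def assert_fp8_cast_weights(model: nn.Module) -> None:
    # Hypothetical helper: every linear weight should be the FSDP-aware subclass so
    # that FSDP2 all-gathers it as float8.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            assert isinstance(module.weight, WeightWithDynamicFloat8CastTensor)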
