This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9facff8

move WeightWithDynamicFloat8CastTensor to fsdp_utils.py
Summary: Refactor in preparation for adding delayed scaling support to float8 all-gather; no logic change.

ghstack-source-id: f4a7638
Pull Request resolved: #310
1 parent 5744c0c commit 9facff8

File tree: 4 files changed (+115, -108 lines)


float8_experimental/float8_dynamic_utils.py

Lines changed: 0 additions & 106 deletions
@@ -3,27 +3,16 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-"""
-A wrapper around a `torch.nn.Linear` module which does fp8 compute.
-"""
-
-from typing import Any, Optional, Tuple
-
-import float8_experimental.config as config
 
 import torch
-import torch.nn as nn
-import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    merge_mm_configs,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
 from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
-from torch._prims_common import suggest_memory_format
 
 
 @torch._dynamo.allow_in_graph
@@ -66,98 +55,3 @@ def cast_to_float8_e5m2_dynamic_bw(
     gradY: torch.Tensor, mm_config: ScaledMMConfig
 ) -> torch.Tensor:
     return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
-
-
-# FSDP pads its local tensor on dim-0. The subclass should be preserved such
-# that the padded local tensor (and any transformations like copying to GPU)
-# is of the subclass as well.
-_ops_to_preserve_subclass = {
-    torch.ops.aten.empty_like.default,
-    torch.ops.aten.new_zeros.default,
-    torch.ops.aten.slice.Tensor,
-    torch.ops.aten.copy_.default,
-    torch.ops.aten.view.default,
-    torch.ops.aten.as_strided.default,
-    torch.ops.aten._to_copy.default,
-    torch.ops.aten._pin_memory.default,
-}
-
-
-class WeightWithDynamicFloat8CastTensor(torch.Tensor):
-    @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
-        return torch.Tensor._make_wrapper_subclass(
-            cls,
-            tensor.size(),
-            strides=tensor.stride(),
-            storage_offset=tensor.storage_offset(),
-            memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
-            layout=tensor.layout,
-            device=tensor.device,
-            pin_memory=tensor.is_pinned(),
-            requires_grad=tensor.requires_grad,
-        )
-
-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
-        self._tensor = tensor
-        self._mm_config = mm_config
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        if func == torch.ops.aten.detach.default:
-            return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
-            )
-        mm_config: Optional[ScaledMMConfig] = None
-
-        def unwrap(t):
-            nonlocal mm_config
-            if mm_config is None:
-                mm_config = t._mm_config
-            else:
-                mm_config = merge_mm_configs(mm_config, t._mm_config)
-            return t._tensor
-
-        args, kwargs = pytree.tree_map_only(
-            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
-        )
-        out = func(*args, **kwargs)
-        if func not in _ops_to_preserve_subclass:
-            return out
-        return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
-        )
-
-    def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
-
-    @staticmethod
-    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
-
-    def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
-
-    def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
-        return (float8_tensor._data,), (float8_tensor._scale,)
-
-    def fsdp_post_all_gather(
-        self,
-        all_gather_outputs: Tuple[torch.Tensor, ...],
-        metadata: Any,
-        param_dtype: torch.dtype,
-        *,
-        out: Optional[torch.Tensor] = None,
-    ):
-        (data,) = all_gather_outputs
-        (scale,) = metadata
-        if out is not None:
-            assert isinstance(out, Float8Tensor), f"{type(out)}"
-            out._scale = scale
-            return
-        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
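The block deleted above moves verbatim into float8_experimental/fsdp_utils.py (third file below). Its __torch_dispatch__ applies one rule: unwrap to the inner tensor, run the op, and re-wrap the result only for the bookkeeping ops FSDP needs (those in _ops_to_preserve_subclass). A minimal sketch of that behavior, assuming the post-commit import path and that ScaledMMConfig can be constructed with its default fields:

import torch
from float8_experimental.float8_tensor import ScaledMMConfig
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor

# Wrap a plain high-precision weight; the default ScaledMMConfig is a stand-in.
w = WeightWithDynamicFloat8CastTensor(torch.randn(4, 4), ScaledMMConfig())

# Ops FSDP uses for padding/copying preserve the subclass ...
assert isinstance(torch.empty_like(w), WeightWithDynamicFloat8CastTensor)
# ... while ordinary compute ops unwrap and return a plain torch.Tensor.
assert not isinstance(w + 1, WeightWithDynamicFloat8CastTensor)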

float8_experimental/float8_linear.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,6 @@
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
-    WeightWithDynamicFloat8CastTensor,
 )
 
 from float8_experimental.float8_tensor import (
@@ -35,6 +34,8 @@
     tensor_to_amax,
 )
 
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
+
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
     x,

float8_experimental/fsdp_utils.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.utils._pytree as pytree
+from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
+
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
+from torch._prims_common import suggest_memory_format
+
+# FSDP pads its local tensor on dim-0. The subclass should be preserved such
+# that the padded local tensor (and any transformations like copying to GPU)
+# is of the subclass as well.
+_ops_to_preserve_subclass = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+}
+
+
+class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+        self._tensor = tensor
+        self._mm_config = mm_config
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDynamicFloat8CastTensor(
+                args[0]._tensor, args[0]._mm_config
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+        )
+
+    def __tensor_flatten__(self):
+        return ["_tensor"], self._mm_config
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        mm_config = flatten_spec
+        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+
+    def __repr__(self):
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        float8_tensor = cast_to_float8_e4m3_dynamic(
+            self._tensor, self._mm_config, reduce_amax=True
+        )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
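The last two methods are the hooks FSDP2 calls around its all-gather so that only fp8 bytes plus a scale travel over the network. Below is a single-process sketch of how they compose, for illustration only: mesh=None and world size 1 (no process group) are simplifying assumptions, ScaledMMConfig() stands in for a real config, and in practice FSDP2 drives these calls across a device mesh.

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor

weight = WeightWithDynamicFloat8CastTensor(torch.randn(16, 16), ScaledMMConfig())

# Pre all-gather: dynamically cast to float8 (e4m3) and hand FSDP only the raw
# fp8 data tensor, with the scale passed along as metadata.
(data,), (scale,) = weight.fsdp_pre_all_gather(mesh=None)

# Post all-gather: at world size 1 the "gathered" output is just `data`; the
# hook rebuilds a Float8Tensor that carries the scale and mm_config forward.
out, _ = weight.fsdp_post_all_gather((data,), (scale,), torch.bfloat16)
assert isinstance(out, Float8Tensor)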

test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 1 addition & 1 deletion
@@ -8,9 +8,9 @@
 import torch._dynamo.testing
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from test_fsdp2_common import (
     check_parity_bf16_mp,
     check_parity_no_mp,
