+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD 3-Clause license found in the
+ # LICENSE file in the root directory of this source tree.
+
import math
- from typing import List
+ from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
- from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
- from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+ import torch.utils._pytree as pytree
+ from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
+
+ from float8_experimental.float8_tensor import (
+     Float8Tensor,
+     merge_mm_configs,
+     ScaledMMConfig,
+ )
+
from float8_experimental.float8_utils import EPS
+ from torch._prims_common import suggest_memory_format


@torch.no_grad()
@@ -19,6 +33,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
        optim.step()
        precompute_float8_dynamic_scale_for_fsdp(model)
    """
+     from float8_experimental.float8_linear import Float8Linear, TensorScalingType
    from torch.distributed._tensor import DTensor

    if any(
@@ -50,3 +65,127 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    scales = torch.split(scale_tensor, 1)  # Replicate
    for scale, float8_linear in zip(scales, float8_linears):
        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+
+
+ # FSDP pads its local tensor on dim-0. The subclass should be preserved such
+ # that the padded local tensor (and any transformations like copying to GPU)
+ # is of the subclass as well.
+ _ops_to_preserve_subclass = {
+     torch.ops.aten.empty_like.default,
+     torch.ops.aten.new_zeros.default,
+     torch.ops.aten.slice.Tensor,
+     torch.ops.aten.copy_.default,
+     torch.ops.aten.view.default,
+     torch.ops.aten.as_strided.default,
+     torch.ops.aten._to_copy.default,
+     torch.ops.aten._pin_memory.default,
+ }
+
+
+ class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+     @staticmethod
+     def __new__(
+         cls,
+         tensor: torch.Tensor,
+         mm_config: ScaledMMConfig,
+         precomputed_scale: Optional[torch.Tensor] = None,
+     ):
+         return torch.Tensor._make_wrapper_subclass(
+             cls,
+             tensor.size(),
+             strides=tensor.stride(),
+             storage_offset=tensor.storage_offset(),
+             memory_format=suggest_memory_format(tensor),
+             dtype=tensor.dtype,
+             layout=tensor.layout,
+             device=tensor.device,
+             pin_memory=tensor.is_pinned(),
+             requires_grad=tensor.requires_grad,
+         )
+
+     def __init__(
+         self,
+         tensor: torch.Tensor,
+         mm_config: ScaledMMConfig,
+         precomputed_scale: Optional[torch.Tensor] = None,
+     ):
+         self._tensor = tensor
+         self._mm_config = mm_config
+         # for dynamic scaling
+         # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+         # for all float8 parameters after optimizer step
+         self._precomputed_scale = precomputed_scale
+
+     @classmethod
+     def __torch_dispatch__(cls, func, types, args, kwargs=None):
+         if func == torch.ops.aten.detach.default:
+             return WeightWithDynamicFloat8CastTensor(
+                 args[0]._tensor, args[0]._mm_config
+             )
+         mm_config: Optional[ScaledMMConfig] = None
+
+         def unwrap(t):
+             nonlocal mm_config
+             if mm_config is None:
+                 mm_config = t._mm_config
+             else:
+                 mm_config = merge_mm_configs(mm_config, t._mm_config)
+             return t._tensor
+
+         args, kwargs = pytree.tree_map_only(
+             WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
+         )
+         out = func(*args, **kwargs)
+         if func not in _ops_to_preserve_subclass:
+             return out
+         return pytree.tree_map_only(
+             torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+         )
+
+     def __tensor_flatten__(self):
+         if self._precomputed_scale:
+             return ["_tensor", "_precomputed_scale"], self._mm_config
+         else:
+             return ["_tensor"], self._mm_config
+
+     @staticmethod
+     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+         mm_config = flatten_spec
+         return WeightWithDynamicFloat8CastTensor(
+             inner_tensors["_tensor"],
+             mm_config,
+             inner_tensors.get("_precomputed_scale", None),
+         )
+
+     def __repr__(self):
+         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+
+     def fsdp_pre_all_gather(self, mesh):
+         if self._precomputed_scale is not None:
+             float8_tensor = Float8Tensor.to_float8(
+                 self._tensor,
+                 self._precomputed_scale,
+                 torch.float8_e4m3fn,
+                 mm_config=self._mm_config,
+             )
+         else:
+             float8_tensor = cast_to_float8_e4m3_dynamic(
+                 self._tensor, self._mm_config, reduce_amax=True
+             )
+         return (float8_tensor._data,), (float8_tensor._scale,)
+
+     def fsdp_post_all_gather(
+         self,
+         all_gather_outputs: Tuple[torch.Tensor, ...],
+         metadata: Any,
+         param_dtype: torch.dtype,
+         *,
+         out: Optional[torch.Tensor] = None,
+     ):
+         (data,) = all_gather_outputs
+         (scale,) = metadata
+         if out is not None:
+             assert isinstance(out, Float8Tensor), f"{type(out)}"
+             out._scale = scale
+             return
+         return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
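
The new WeightWithDynamicFloat8CastTensor relies on __torch_dispatch__ plus the _ops_to_preserve_subclass allow-list so that FSDP's dim-0 padding, slicing, and copies keep returning the subclass instead of plain tensors. Below is a minimal, self-contained sketch of that unwrap/re-wrap pattern using only plain torch; TaggedTensor and its op list are illustrative stand-ins, not part of float8_experimental.

import torch
import torch.utils._pytree as pytree

# Hypothetical allow-list: ops FSDP-like code issues while padding/copying a shard.
_PRESERVE = {
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.new_zeros.default,
    torch.ops.aten.copy_.default,
}


class TaggedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor, tag: str):
        return torch.Tensor._make_wrapper_subclass(
            cls, tensor.size(), dtype=tensor.dtype, device=tensor.device
        )

    def __init__(self, tensor: torch.Tensor, tag: str):
        self._tensor = tensor
        self._tag = tag

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        tag = None

        def unwrap(t):
            nonlocal tag
            tag = t._tag
            return t._tensor

        # unwrap the subclass, run the op on the inner tensors
        args, kwargs = pytree.tree_map_only(TaggedTensor, unwrap, (args, kwargs or {}))
        out = func(*args, **kwargs)
        if func not in _PRESERVE:
            # ops outside the allow-list return plain tensors
            return out
        # padding/slicing/copy ops re-wrap their outputs so the subclass survives
        return pytree.tree_map_only(torch.Tensor, lambda x: TaggedTensor(x, tag), out)


w = TaggedTensor(torch.randn(4, 4), tag="weight")
padded = w.new_zeros(6, 4)      # in the allow-list -> still a TaggedTensor
padded[: w.size(0)].copy_(w)    # slice + copy_ also stay in the subclass
print(type(padded).__name__)    # TaggedTensor
print(type(w + 1).__name__)     # Tensor (add is not preserved)

Ops outside the allow-list intentionally fall back to plain tensors, which mirrors why the set in the diff only contains the ops FSDP issues while padding and moving the local shard.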
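
fsdp_pre_all_gather / fsdp_post_all_gather are the FSDP2 all-gather extension hooks: the first returns the tensors to all-gather plus out-of-band metadata (here the fp8 data and its scale), and the second rebuilds the unsharded parameter from the gathered outputs. The sketch below imitates that contract with a hypothetical int8 stand-in and calls the hooks by hand only to show the data flow; in practice FSDP invokes them around its all-gather, and only the hook signatures are taken from the diff.

import torch

# Hypothetical low-precision weight wrapper; the int8 cast is a stand-in for
# the float8 cast, only the hook signatures mirror the diff above.
class Int8CastWeight(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, tensor.size(), dtype=tensor.dtype, device=tensor.device
        )

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    def fsdp_pre_all_gather(self, mesh):
        # return (tensors for FSDP to all-gather,), (metadata passed alongside,)
        scale = self._tensor.abs().amax().clamp(min=1e-12) / 127.0
        data = (self._tensor / scale).round().clamp(-127, 127).to(torch.int8)
        return (data,), (scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            # FSDP reuses a previously built output; only the metadata is refreshed
            out._scale = scale
            return
        # first call: build the unsharded parameter and hand back the gathered
        # buffers it still references
        return data.to(param_dtype) * scale, (data,)


w = Int8CastWeight(torch.randn(8, 8))
inputs, meta = w.fsdp_pre_all_gather(mesh=None)
# (FSDP would all-gather `inputs` across ranks between these two calls)
unsharded, _ = w.fsdp_post_all_gather(inputs, meta, torch.float32)
print(unsharded.shape, unsharded.dtype)  # torch.Size([8, 8]) torch.float32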