@@ -13,7 +13,8 @@
 """
 
 import dataclasses
+import functools
 
-from typing import Any, cast, Optional, Tuple, Union
+from typing import Any, Callable, cast, Optional, Tuple, Union
 
 import float8_experimental.config as config
@@ -29,6 +30,7 @@
     tensor_to_amax,
     to_fp8_saturated,
 )
+from torch.utils._pytree import tree_map
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -221,9 +223,7 @@ def cast_x_to_float8(
         )
         return x_fp8
 
-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+    def cast_w_to_float8(self, w: torch.Tensor) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
             w,
@@ -232,7 +232,7 @@ def cast_w_to_float8(
             self.fp8_scale_w,
             scale_fn_name,
             torch.float8_e4m3fn,
-            is_amax_initialized,
+            self.is_amax_initialized,
         )
 
         w_fp8 = Float8Tensor.to_float8(
@@ -296,7 +296,7 @@ def forward(self, x):
         w_fp8 = (
             self.weight
             if isinstance(self.weight, Float8Tensor)
-            else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+            else self.cast_w_to_float8(self.weight)
         )
 
         y = torch.matmul(x_fp8, w_fp8.t())
@@ -332,7 +332,9 @@ def from_float(
         # with torch.device("meta"):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = (
-            nn.Parameter(Float8LinearWeightTensor(mod.weight))
+            nn.Parameter(
+                Float8LinearWeightTensor(mod.weight, new_mod.cast_w_to_float8, emulate)
+            )
             if use_fp8_all_gather
             else mod.weight
         )
@@ -344,12 +346,32 @@ def from_float(
 
 
 class Float8LinearWeightTensor(torch.Tensor):
-    # TODO: Remove `module` arg, save state on subclass, and propagate it.
-    def fsdp_pre_all_gather(
-        self, module: nn.Module
-    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
-        float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
-        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+    def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        return cls._make_subclass(cls, tensor, tensor.requires_grad)
+
+    def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        super().__init__()
+        self.cast_fn = cast_fn
+        self.emulate = emulate
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def wrap(cast_fn: Callable, emulate: bool, o: Any):
+            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
+                return cls(o, cast_fn, emulate)
+            return o
+
+        with torch._C.DisableTorchFunctionSubclass():
+            if isinstance(args[0], cls):
+                out = func(*args, **kwargs)
+                return tree_map(functools.partial(wrap, args[0].cast_fn, args[0].emulate), out)
+            return func(*args, **kwargs)
+
+    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = self.cast_fn(self)
+        return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(
         self,
@@ -359,8 +381,8 @@ def fsdp_post_all_gather(
         *,
         out: Optional[torch.Tensor] = None,
     ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
-        (data,) = all_gather_outputs
-        scale, emulate = metadata
+        data, = all_gather_outputs
+        scale, = metadata
         if out is not None:
             out = cast(Float8Tensor, out)
             assert (
@@ -369,4 +391,4 @@ def fsdp_post_all_gather(
             )
             out._scale = scale
             return
-        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
+        return Float8Tensor(data, scale, param_dtype, self.emulate), (data,)
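
For context, the diff moves the cast function and the `emulate` flag onto the weight subclass itself and uses `__torch_function__` to re-wrap tensor outputs, so that when FSDP chunks or views the parameter, the resulting shards are still `Float8LinearWeightTensor`s that know how to cast themselves. The following is a minimal, self-contained sketch of that mechanism, not part of the PR: the class name `MetadataTensor` and its `extra` field are hypothetical stand-ins, and only the wrapping pattern mirrors the code above.

import functools
from typing import Any

import torch
from torch.utils._pytree import tree_map


class MetadataTensor(torch.Tensor):
    def __new__(cls, tensor: torch.Tensor, extra: Any):
        # Reuse the storage of `tensor`; only the Python class changes.
        return cls._make_subclass(cls, tensor, tensor.requires_grad)

    def __init__(self, tensor: torch.Tensor, extra: Any):
        super().__init__()
        self.extra = extra  # hypothetical per-tensor state to carry along

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def wrap(extra: Any, o: Any):
            # Outputs typically come back as plain tensors inside the
            # disabled-subclass region, so re-wrap them to keep `extra`.
            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
                return cls(o, extra)
            return o

        with torch._C.DisableTorchFunctionSubclass():
            if isinstance(args[0], cls):
                out = func(*args, **kwargs)
                return tree_map(functools.partial(wrap, args[0].extra), out)
            return func(*args, **kwargs)


w = MetadataTensor(torch.randn(4, 4), "some-state")
v = w.t().contiguous()        # ordinary ops return MetadataTensor again
assert isinstance(v, MetadataTensor) and v.extra == "some-state"
c0, c1 = torch.chunk(w, 2)    # chunking, as FSDP does when sharding, keeps it too
assert c0.extra == c1.extra == "some-state"

With the state living on the subclass, `fsdp_pre_all_gather` no longer needs a `module` argument: it casts the sharded weight with the saved `cast_fn` and returns the raw fp8 data for FSDP to all-gather plus the scale as metadata, while `fsdp_post_all_gather` rebuilds the unsharded `Float8Tensor` from the gathered data, the scale, and the `emulate` flag stored on `self`, or writes into the preallocated `out` buffer when one is provided.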