@@ -13,8 +13,9 @@
 """

 import dataclasses
+import functools

-from typing import Any, cast, Optional, Tuple, Union
+from typing import Any, Callable, cast, Optional, Tuple, Union

 import float8_experimental.config as config

@@ -29,6 +30,7 @@
     tensor_to_amax,
     to_fp8_saturated,
 )
+from torch.utils._pytree import tree_map


 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -222,9 +224,7 @@ def cast_x_to_float8(
         )
         return x_fp8

-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+    def cast_w_to_float8(self, w: torch.Tensor) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
             w,
@@ -233,7 +233,7 @@ def cast_w_to_float8(
             self.fp8_scale_w,
             scale_fn_name,
             torch.float8_e4m3fn,
-            is_amax_initialized,
+            self.is_amax_initialized,
         )

         w_fp8 = Float8Tensor.to_float8(
@@ -297,7 +297,7 @@ def forward(self, x):
         w_fp8 = (
             self.weight
             if isinstance(self.weight, Float8Tensor)
-            else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+            else self.cast_w_to_float8(self.weight)
         )

         y = torch.matmul(x_fp8, w_fp8.t())
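
Aside (not part of the diff): a minimal, standalone sketch of the saturated fp8 cast that `cast_w_to_float8` / `Float8Tensor.to_float8` perform conceptually, where the scale is derived from an observed amax and values are clamped to the e4m3 range. The helper name `fp8_saturated_cast` is made up for illustration; the real code keeps the data in fp8 and multiplies scaled fp8 operands directly rather than dequantizing.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def fp8_saturated_cast(t: torch.Tensor, amax: torch.Tensor):
    # Map the observed amax onto the fp8 dynamic range, clamp, then cast.
    scale = FP8_MAX / torch.clamp(amax, min=1e-12)
    t_scaled = (t.to(torch.float32) * scale).clamp(-FP8_MAX, FP8_MAX)
    return t_scaled.to(torch.float8_e4m3fn), scale

w = torch.randn(32, 16)
w_fp8, w_scale = fp8_saturated_cast(w, w.abs().max().to(torch.float32))
x = torch.randn(4, 16)
# Reference matmul via dequantization; production kernels use a scaled fp8 matmul.
y = x @ (w_fp8.to(torch.float32) / w_scale).t()
```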
@@ -333,7 +333,9 @@ def from_float(
         # with torch.device("meta"):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = (
-            nn.Parameter(Float8LinearWeightTensor(mod.weight))
+            nn.Parameter(
+                Float8LinearWeightTensor(mod.weight, new_mod.cast_w_to_float8, emulate)
+            )
             if use_fp8_all_gather
             else mod.weight
         )
@@ -345,12 +347,34 @@ def from_float(


 class Float8LinearWeightTensor(torch.Tensor):
-    # TODO: Remove `module` arg, save state on subclass, and propagate it.
-    def fsdp_pre_all_gather(
-        self, module: nn.Module
-    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
-        float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
-        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+    def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        return cls._make_subclass(cls, tensor, tensor.requires_grad)
+
+    def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        super().__init__()
+        self.cast_fn = cast_fn
+        self.emulate = emulate
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def wrap(cast_fn: Callable, emulate: bool, o: Any):
+            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
+                return cls(o, cast_fn, emulate)
+            return o
+
+        with torch._C.DisableTorchFunctionSubclass():
+            if isinstance(args[0], cls):
+                out = func(*args, **kwargs)
+                return tree_map(
+                    functools.partial(wrap, args[0].cast_fn, args[0].emulate), out
+                )
+            return func(*args, **kwargs)
+
+    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = self.cast_fn(self)
+        return (float8_tensor._data,), (float8_tensor._scale,)

     def fsdp_post_all_gather(
         self,
@@ -361,7 +385,7 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
         (data,) = all_gather_outputs
-        scale, emulate = metadata
+        (scale,) = metadata
         if out is not None:
             out = cast(Float8Tensor, out)
             assert (
@@ -370,4 +394,4 @@ def fsdp_post_all_gather(
             )
             out._scale = scale
             return
-        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
+        return Float8Tensor(data, scale, param_dtype, self.emulate), (data,)
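
To make the `__torch_function__` change above easier to follow, here is a minimal, self-contained sketch of the same metadata-propagation pattern: a `torch.Tensor` subclass carries extra state and re-wraps plain tensor outputs so that state survives ops such as `.t()`. The `TaggedTensor` name and its `tag` field are illustrative only, not part of this repository.

```python
import functools
from typing import Any

import torch
from torch.utils._pytree import tree_map


class TaggedTensor(torch.Tensor):
    # Toy subclass: carries a `tag` and re-wraps op outputs so the tag survives.
    def __new__(cls, tensor: torch.Tensor, tag: str):
        return cls._make_subclass(cls, tensor, tensor.requires_grad)

    def __init__(self, tensor: torch.Tensor, tag: str):
        super().__init__()
        self.tag = tag

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def wrap(tag: str, o: Any):
            # Re-wrap plain tensor outputs so the extra state propagates.
            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
                return cls(o, tag)
            return o

        with torch._C.DisableTorchFunctionSubclass():
            if isinstance(args[0], cls):
                out = func(*args, **kwargs)
                return tree_map(functools.partial(wrap, args[0].tag), out)
            return func(*args, **kwargs)


t = TaggedTensor(torch.randn(4, 4), "weight")
view = t.t()  # goes through __torch_function__ and gets re-wrapped
assert isinstance(view, TaggedTensor) and view.tag == "weight"
```

This is the same idea the diff uses to keep `cast_fn` and `emulate` on the weight subclass across FSDP's slicing and view operations, which is what lets `fsdp_pre_all_gather` drop its old `module` argument.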