@@ -24,7 +24,15 @@ def to_fp8_no_autograd(
    """Convert a tensor to float8 without autograd
    This is used in multiple places in the codebase to convert a tensor to float8

-    This function will calculate the scale, do the scaling, and then convert to a Float8Tensor
+    This function will apply the scaling, and then convert to a Float8Tensor
+
+    Note:
+        We will call this function with a DTensor subclass. Ideally this would be an aten OP
+        that DTensor could overload to ensure proper semantics. There are some technical issues
+        with that composing with FakeTensor, so we special case here.
+
+        DTensor Invariant: DTensor must always be the outermost tensor subclass
+
    Args:
        x: the tensor to convert
        scale: the scale to use to convert the tensor
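Illustrative aside (not part of the diff): the docstring above says the function applies a scale and then converts to float8. Below is a minimal self-contained sketch of that step, assuming a per-tensor amax-derived scale for illustration; the real function takes a precomputed scale and wraps the result in a Float8Tensor rather than returning raw float8 bits.

import torch

x = torch.randn(16, 16, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max               # 448.0 for e4m3fn
scale = fp8_max / x.abs().max().float()                      # illustrative per-tensor scale from amax
bits_fp8 = (x.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
print(bits_fp8.dtype)                                         # torch.float8_e4m3fn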
@@ -50,6 +58,32 @@ def to_fp8_no_autograd(
    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)


+def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
+    """Convert a tensor from float8 without autograd
+
+    This function will handle 3 cases:
+    1. If the tensor is a DTensor, it will convert the inner tensor to the original precision
+    2. If the tensor is a Float8Tensor, it will convert the tensor to the original precision
+    3. If the tensor is a regular tensor, it will be passed through unchanged
+
+    Args:
+        x: the tensor to convert
+    """
+
+    def to_original_precision(grad):
+        if isinstance(grad, Float8Tensor):
+            return grad.to_original_precision()
+        else:
+            return grad
+
+    if isinstance(x, DTensor):
+        local_grad = x.to_local()
+        original_precision_grad = to_original_precision(local_grad)
+        return DTensor.from_local(original_precision_grad, x.device_mesh, x.placements)
+    else:
+        return to_original_precision(x)
+
+
@torch._dynamo.allow_in_graph
class ToFloat8ConstrFunc(torch.autograd.Function):
    """
@@ -62,17 +96,16 @@ def forward(
        tensor: torch.Tensor,
        scale: torch.Tensor,
        float8_dtype=torch.float8_e4m3fn,
-        amax_buffer=None,
+        amax_buffer: Optional[torch.Tensor] = None,
        emulate: bool = False,
    ):
-        """Converts a higher precision tensor to float8 in a differentiable way.
-
-        Note:
-            We will call this function with a DTensor subclass. Ideally this would be an aten OP
-            that DTensor could overload to ensure proper semantics. There are some techincal issues
-            with that composing with FakeTensor, so we special case here.
-
-            DTensor Invariant: DTensor must always be the outer most tensor subclass
+        """Autograd-enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
+        Args:
+            tensor: the tensor to convert
+            scale: the scale to use to convert the tensor
+            float8_dtype: the float8 dtype, either torch.float8_e4m3fn or torch.float8_e5m2
+            amax_buffer: an optional buffer to store the amax value in prior to conversion
+            emulate: whether to emulate the matmuls in fp32
        """
        if amax_buffer is not None:
            amax_buffer.fill_(tensor_to_amax(tensor))
@@ -81,26 +114,8 @@ def forward(

    @staticmethod
    def backward(ctx, g):
-        def to_original_precision(grad):
-            if isinstance(grad, Float8Tensor):
-                return grad.to_original_precision()
-            else:
-                return grad
-
-        if isinstance(g, DTensor):
-            local_grad = g.to_local()
-            original_precision_grad = to_original_precision(local_grad)
-            return (
-                DTensor.from_local(
-                    original_precision_grad, g.device_mesh, g.placements
-                ),
-                None,
-                None,
-                None,
-                None,
-            )
-        else:
-            return to_original_precision(g), None, None, None, None
+        grad = from_fp8_no_autograd(g)
+        return grad, None, None, None, None


@torch._dynamo.allow_in_graph
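Illustrative aside (not part of the diff): a self-contained sketch of the forward/backward pattern ToFloat8ConstrFunc ends up with after this change. ToFloat8Sketch and tensor_to_amax_sketch are hypothetical stand-ins; to stay runnable without Float8Tensor, the forward fake-quantizes (casts to float8 and straight back), while the backward mirrors the new two-line body, returning the gradient plus one None per extra forward argument.

import torch

def tensor_to_amax_sketch(t: torch.Tensor) -> torch.Tensor:
    return t.abs().max()                                      # stand-in for the repo's tensor_to_amax

class ToFloat8Sketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, float8_dtype, amax_buffer):
        if amax_buffer is not None:
            amax_buffer.fill_(tensor_to_amax_sketch(tensor))  # populate amax before converting
        return tensor.to(float8_dtype).to(tensor.dtype)       # fake-quantize round trip

    @staticmethod
    def backward(ctx, g):
        return g, None, None                                  # one gradient per forward input

amax_buffer = torch.zeros(())
x = torch.randn(8, requires_grad=True)
y = ToFloat8Sketch.apply(x, torch.float8_e4m3fn, amax_buffer)
y.sum().backward()
print(amax_buffer.item(), x.grad.dtype)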