@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Literal, Tuple
 
 import torch
 import torch.distributed as dist
@@ -14,22 +14,40 @@
 
 # define the e4m3/e5m2 constants
 E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
 E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
 
 FP16_MAX_POS = torch.finfo(torch.float16).max
 
 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
+IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+
 
 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    """Converts the amax value of a tensor to the fp8 scale.
+    Args:
+        amax: The amax value of the tensor.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+    """
     scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2:
         res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
     # Ensure that the scale is representable in float16,
     # this helps when amax is small. We are assuming that we don't need
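For orientation, a minimal usage sketch of the updated amax_to_scale. The tensor contents, shapes, and the fnuz-vs-fn selection are illustrative assumptions, and the hunk above is truncated; it is assumed the function finishes by returning the computed scale.

x = torch.randn(16, 16, dtype=torch.bfloat16)
amax = x.abs().max().to(torch.float32)

# On ROCm builds the fnuz float8 variants are the natively supported ones,
# which is what the new IS_AMD flag and elif branches above cater to.
float8_dtype = torch.float8_e4m3fnuz if IS_AMD else torch.float8_e4m3fn
scale = amax_to_scale(amax, float8_dtype, x.dtype)  # float32 scale tensor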
@@ -42,11 +60,18 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
 
 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: Literal["max"],
 ):
+    """Takes in a history of amax values and returns a scale tensor.
+    Args:
+        amax_history: A tensor containing the history of amax values.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
         return amax_to_scale(amax, float8_dtype, orig_dtype)
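A sketch of how amax_history_to_scale might be called under delayed scaling; the window length and values are made up for illustration.

# Rolling buffer of recent per-iteration amax values (hypothetical).
amax_history = torch.tensor([0.5, 2.0, 1.25], dtype=torch.float32)

# "max" is the only supported fn type: the scale is derived from the
# largest amax in the window (2.0 here), so no value in the window overflows.
scale = amax_history_to_scale(
    amax_history, torch.float8_e4m3fn, torch.bfloat16, "max"
)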
@@ -58,9 +83,15 @@ def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
     orig_dtype: torch.dtype,
-    history_to_scale_fn_type: str,
+    history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
-    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    """Takes in a stack of amax_history tensors and returns a scale tensor.
+    Args:
+        amax_history: A 2D tensor containing a stack of amax histories.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
         return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
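The stacked variant vectorizes the same reduction across several histories at once, one row per tracked tensor; the shapes below are illustrative.

# Two hypothetical histories, e.g. for an input and a weight tensor.
amax_histories = torch.stack(
    [
        torch.tensor([0.5, 2.0, 1.25]),
        torch.tensor([4.0, 3.0, 3.5]),
    ]
)
# torch.max(..., dim=1) reduces each row, yielding one scale per history.
scales = amax_history_to_scale_stack(
    amax_histories, torch.float8_e4m3fn, torch.bfloat16, "max"
)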
@@ -90,21 +121,41 @@ def tensor_to_scale(
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
-def to_fp8_saturated(x, float8_dtype: torch.dtype):
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a saturated fp8 tensor.
+
+    Note:
+        The default behavior in PyTorch for casting to `float8_e4m3fn`
+        and `e5m2` is to not saturate. In this context, we should saturate.
+        A common case where we want to saturate is when the history of a
+        tensor has a maximum value of `amax1`, and the current amax value
+        is `amax2`, where `amax1 < amax2`. This is common when using delayed
+        scaling.
+    """
+
     if float8_dtype == torch.float8_e4m3fn:
         x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    else:
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2:
         x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     return x.to(float8_dtype)
 
 
-def compute_error(x, y):
+def compute_error(x: torch.Tensor, y: torch.Tensor):
+    """Computes the error between two tensors in dB.
+
+    For more details see:
+    https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+        x: The original tensor.
+        y: The tensor to compare to the original tensor.
+    """
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
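Tying the helpers together, a hedged end-to-end sketch. It assumes tensor_to_scale, which appears only as hunk context above, takes a tensor and a float8 dtype and returns its scale.

x = torch.randn(64, 64, dtype=torch.float32)
scale = tensor_to_scale(x, torch.float8_e4m3fn)  # assumed signature

# Saturated cast: values that overflow the e4m3 range after scaling are
# clamped rather than becoming inf/NaN, which is exactly the delayed-scaling
# case described in the docstring, where a stale scale lags the current amax.
x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fn)

# Dequantize and measure fidelity as an SNR in dB; higher means closer.
x_deq = x_fp8.to(torch.float32) / scale
snr_db = compute_error(x, x_deq)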