 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Tuple
+from typing import Literal, Tuple

 import torch
 import torch.distributed as dist

 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html

-# define the e4m3/e5m2 constants
-E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
-E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
-
-FP16_MAX_POS = torch.finfo(torch.float16).max
-
 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12

+IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+FP8_TYPES = {
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+    torch.float8_e4m3fnuz,
+    torch.float8_e5m2fnuz,
+}
+
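As context for the change above: the hard-coded `E4M3_MAX_POS`/`E5M2_MAX_POS` constants are replaced by per-dtype `torch.finfo(dtype).max` lookups, which also cover the AMD `fnuz` variants. A minimal sketch of what that lookup returns (assuming a PyTorch build with the fp8 dtypes):

```python
import torch

# Per-dtype max values, read on demand instead of module-level constants.
for dtype in (
    torch.float8_e4m3fn,    # max 448.0
    torch.float8_e5m2,      # max 57344.0
    torch.float8_e4m3fnuz,  # max 240.0 (AMD/fnuz variant)
    torch.float8_e5m2fnuz,  # max 57344.0 (AMD/fnuz variant)
):
    print(dtype, torch.finfo(dtype).max)
```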
 
 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    """Converts the amax value of a tensor to the fp8 scale.
+    Args:
+        amax: The amax value of the tensor.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+    """
     scale = torch.empty_like(amax, dtype=torch.float32)
-    if float8_dtype == torch.float8_e4m3fn:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
-        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+    if float8_dtype in FP8_TYPES:
+        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

     # Ensure that the scale is representable in float16,
     # this helps when amax is small. We are assuming that we don't need
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=FP16_MAX_POS)
+        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
     scale.copy_(res)
     return scale
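A quick usage sketch for the updated `amax_to_scale`; the values are hypothetical, chosen so the arithmetic is easy to check:

```python
import torch

# amax would normally be torch.max(torch.abs(t)) of the tensor being cast.
amax = torch.tensor(3.5)
scale = amax_to_scale(amax, torch.float8_e4m3fn, torch.float32)
# finfo(float8_e4m3fn).max / 3.5 == 448.0 / 3.5 == 128.0
assert scale.item() == 128.0
```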
 
 
 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: Literal["max"],
 ):
+    """Takes in a history of amax values and returns a scale tensor.
+    Args:
+        amax_history: A tensor containing the history of amax values.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
         return amax_to_scale(amax, float8_dtype, orig_dtype)
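For illustration, a toy call to `amax_history_to_scale` under the only supported fn type, `"max"` (history values are made up):

```python
import torch

history = torch.tensor([1.0, 2.0, 4.0])  # hypothetical recent amax values
scale = amax_history_to_scale(history, torch.float8_e4m3fn, torch.float32, "max")
# max(history) == 4.0, so scale == 448.0 / 4.0 == 112.0
```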
@@ -58,9 +75,15 @@ def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
     orig_dtype: torch.dtype,
-    history_to_scale_fn_type: str,
+    history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
-    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    """Takes in a stack of amax_history tensors and returns a scale tensor.
+    Args:
+        amax_history: A 2D tensor containing a stack of amax histories.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
         return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
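The stacked variant reduces each row of a 2D history to its own amax, producing one scale per tensor. A small sketch with made-up numbers:

```python
import torch

stack = torch.tensor([[1.0, 2.0], [8.0, 4.0]])  # one history row per tensor
scales = amax_history_to_scale_stack(stack, torch.float8_e4m3fn, torch.float32, "max")
# row maxes are [2.0, 8.0] -> scales [448/2, 448/8] == [224.0, 56.0]
```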
@@ -90,21 +113,35 @@ def tensor_to_scale(
     return amax_to_scale(amax, float8_dtype, x.dtype)


-def to_fp8_saturated(x, float8_dtype: torch.dtype):
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
-    if float8_dtype == torch.float8_e4m3fn:
-        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a saturated fp8 tensor.
+
+    Note:
+        The default behavior in PyTorch for casting to `float8_e4m3fn`
+        and `e5m2` is to not saturate. In this context, we should saturate.
+        A common case where we want to saturate is when the history of a
+        tensor has a maximum value of `amax1`, and the current amax value
+        is `amax2`, where `amax1 < amax2`. This is common when using delayed
+        scaling.
+    """
+    if float8_dtype in FP8_TYPES:
+        max_value = torch.finfo(float8_dtype).max
+        x = x.clamp(min=-max_value, max=max_value)
+        return x.to(float8_dtype)
     else:
-        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
-    return x.to(float8_dtype)
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
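To see what the saturation buys, compare against a plain cast. The example values are hypothetical; per the docstring above, the default cast does not saturate, and since `float8_e4m3fn` has no inf, an unsaturated overflow is expected to land on NaN:

```python
import torch

x = torch.tensor([1000.0])  # above the e4m3fn max of 448.0
sat = to_fp8_saturated(x, torch.float8_e4m3fn)  # clamps to 448.0 before casting
raw = x.to(torch.float8_e4m3fn)                 # default cast: overflows
print(sat.float(), raw.float())                 # tensor([448.]) vs tensor([nan])
```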
 
 
-def compute_error(x, y):
+def compute_error(x: torch.Tensor, y: torch.Tensor):
+    """Computes the error between two tensors in dB.
+
+    For more details see:
+        https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+        x: The original tensor.
+        y: The tensor to compare to the original tensor.
+    """
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
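`compute_error` is effectively an SQNR in dB: signal power over noise power, on a log scale. A short usage sketch with arbitrary inputs:

```python
import torch

x = torch.randn(1024)
y = x + 0.01 * torch.randn(1024)  # slightly perturbed copy of x
print(compute_error(x, y))  # higher dB = closer match; identical inputs give inf
```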
@@ -113,11 +150,19 @@ def compute_error(x, y):
 def fp8_tensor_statistics(
     tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
 ) -> Tuple[int, ...]:
-    """Calculate FP8 tensor stats"""
-    if float8_dtype == torch.float8_e4m3fn:
-        FP8_MAX = E4M3_MAX_POS
-    else:  # e5m2
-        FP8_MAX = E5M2_MAX_POS
+    """Calculate FP8 tensor stats
+
+    Args:
+        tensor: The tensor to calculate stats for.
+        float8_dtype: The float8 dtype.
+
+    Returns:
+        A tuple containing the number of zeros and the number of max values.
+    """
+    if float8_dtype in FP8_TYPES:
+        FP8_MAX = torch.finfo(float8_dtype).max
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype)
     num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item()
     num_zero = (tensor_orig_type == 0).sum().item()