# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+from typing import Literal
+
import torch
import torch.distributed as dist


# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max

FP16_MAX_POS = torch.finfo(torch.float16).max

# avoid division by zero when calculating scale
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12

+IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+

@torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    """Converts the amax value of a tensor to the fp8 scale.
+    Args:
+        amax: The amax value of the tensor.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == torch.float8_e4m3fn:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure that the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
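For reviewers, a rough sketch of what the new dtype branches mean in practice. The import path is assumed to be float8_experimental.float8_utils, the expected outputs assume the (elided) tail of amax_to_scale returns the computed ratio as a float32 tensor, and the 448/240 bounds are what torch.finfo reports for e4m3fn versus the fnuz variant:

import torch
from float8_experimental.float8_utils import amax_to_scale  # assumed module path

amax = torch.tensor(100.0)
# float8_e4m3fn saturates at 448.0, so the scale is roughly 448 / 100 = 4.48
print(amax_to_scale(amax, torch.float8_e4m3fn, torch.float32))
# float8_e4m3fnuz (the no-inf variant used on AMD/ROCm) saturates at 240.0,
# so the same amax yields a smaller scale of roughly 2.40
print(amax_to_scale(amax, torch.float8_e4m3fnuz, torch.float32))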
@@ -40,11 +60,18 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):

@torch.no_grad()
def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: Literal["max"],
):
+    """Takes in a history of amax values and returns a scale tensor.
+    Args:
+        amax_history: A tensor containing the history of amax values.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
    if history_to_scale_fn_type == "max":
        amax = torch.max(amax_history)
        return amax_to_scale(amax, float8_dtype, orig_dtype)
@@ -56,9 +83,15 @@ def amax_history_to_scale_stack(
    amax_history: torch.Tensor,
    float8_dtype: torch.dtype,
    orig_dtype: torch.dtype,
-    history_to_scale_fn_type: str,
+    history_to_scale_fn_type: Literal["max"],
) -> torch.Tensor:
-    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    """Takes in a stack of amax_history tensors and returns a scale tensor.
+    Args:
+        amax_history: A 2D tensor containing a stack of amax histories.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
    if history_to_scale_fn_type == "max":
        amax_stack = torch.max(amax_history, dim=1).values
        return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
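A small usage sketch for the stacked-history path, again assuming the float8_experimental.float8_utils import path; the per-row maximum is taken along dim=1 before being converted to a scale:

import torch
from float8_experimental.float8_utils import amax_history_to_scale_stack  # assumed module path

# Delayed-scaling amax histories for two tensors, stacked into a 2D tensor.
amax_history = torch.tensor(
    [[10.0, 50.0, 30.0],
     [2.0, 1.0, 4.0]]
)
scales = amax_history_to_scale_stack(
    amax_history, torch.float8_e4m3fn, torch.float32, "max"
)
# Row-wise maxima are 50 and 4, so the e4m3fn scales are 448/50 and 448/4.
print(scales)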
@@ -81,26 +114,51 @@ def tensor_to_amax(x, distributed_reduction=False):


@torch.no_grad()
-def tensor_to_scale(x, float8_dtype):
+def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a scale tensor.
+    Args:
+        x: The tensor to calculate the scale for.
+        float8_dtype: The float8 dtype.
+    """
    amax = tensor_to_amax(x)
    return amax_to_scale(amax, float8_dtype, x.dtype)


-def to_fp8_saturated(x, float8_dtype: torch.dtype):
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a saturated fp8 tensor.
+
+    Note:
+        The default behavior in PyTorch for casting to `float8_e4m3fn`
+        and `e5m2` is to not saturate. In this context, we should saturate.
+        A common case where we want to saturate is when the history of a
+        tensor has a maximum value of `amax1`, and the current amax value
+        is `amax2`, where `amax1 < amax2`. This is common when using delayed
+        scaling.
+    """
+
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    else:
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    return x.to(float8_dtype)
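To make the saturation note above concrete, a minimal quantize/dequantize round trip, assuming the same float8_experimental.float8_utils import path:

import torch
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated  # assumed module path

x = torch.randn(16, 16, dtype=torch.float32)
scale = tensor_to_scale(x, torch.float8_e4m3fn)

# Scale up and cast with saturation: anything that would overflow the e4m3fn
# range clamps to +/-448 instead of becoming inf/nan. This matters for delayed
# scaling, where the scale was computed from an older (possibly smaller) amax.
x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fn)
x_roundtrip = x_fp8.to(torch.float32) / scale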


-def compute_error(x, y):
+def compute_error(x: torch.Tensor, y: torch.Tensor):
+    """Computes the error between two tensors in dB.
+
+    For more details see:
+    https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+        x: The original tensor.
+        y: The tensor to compare to the original tensor.
+    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
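And a quick sanity check of compute_error, which reports 20 * log10(||x|| / ||x - y||) in dB, so larger values mean y is closer to x (import path assumed as above):

import torch
from float8_experimental.float8_utils import compute_error  # assumed module path

x = torch.randn(1024)
y = x + 0.01 * torch.randn(1024)  # a slightly perturbed copy of x
print(f"{compute_error(x, y).item():.1f} dB")  # roughly 40 dB for ~1% noise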