This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ba9b5dd

Skipping SAM for now since it hangs
1 parent cdb7867 commit ba9b5dd

File tree

1 file changed: +69 -18 lines changed


float8_experimental/float8_utils.py

Lines changed: 69 additions & 18 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Literal, Tuple
 
 import torch
 import torch.distributed as dist
@@ -14,22 +14,40 @@
 
 # define the e4m3/e5m2 constants
 E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
 E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
 
 FP16_MAX_POS = torch.finfo(torch.float16).max
 
 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
+IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+
 
 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    """Converts the amax value of a tensor to the fp8 scale.
+    Args:
+        amax: The amax value of the tensor.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+    """
     scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2:
         res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
     # Ensure that the scale is representable in float16,
     # this helps when amax is small. We are assuming that we don't need
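For context, here is a rough usage sketch (not part of the commit) of what the new branches do: the fnuz variants have a smaller representable maximum than `float8_e4m3fn`, so the same amax yields a smaller scale. It assumes a PyTorch build that exposes the float8 dtypes.

```python
import torch

# Illustrative sketch only, mirroring amax_to_scale's arithmetic; not part of the diff.
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max          # 448.0
E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max   # 240.0, the fnuz variant used on AMD
EPS = 1e-12

amax = torch.tensor(12.5)
scale_fn = E4M3_MAX_POS / torch.clamp(amax, min=EPS)         # scale used for float8_e4m3fn
scale_fnuz = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)  # smaller max -> smaller scale
print(scale_fn.item(), scale_fnuz.item())
```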
@@ -42,11 +60,18 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
 
 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.Tensor,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: Literal["max"],
 ):
+    """Takes in a history of amax values and returns a scale tensor.
+    Args:
+        amax_history: A tensor containing the history of amax values.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
         return amax_to_scale(amax, float8_dtype, orig_dtype)
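As a rough usage sketch (assuming a recorded 1D amax history and the only supported policy, `"max"`), the history is reduced to its largest entry before being turned into a scale:

```python
import torch

# Hypothetical illustration of the "max" history-to-scale policy; not part of the diff.
amax_history = torch.tensor([0.5, 3.2, 1.7, 2.9])  # amax observed over recent iterations
amax = torch.max(amax_history)                      # 3.2, the largest value in the history
scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)
print(scale.item())  # roughly 448 / 3.2 = 140
```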
@@ -58,9 +83,15 @@ def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
     orig_dtype: torch.dtype,
-    history_to_scale_fn_type: str,
+    history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
-    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    """Takes in a stack of amax_history tensors and returns a scale tensor.
+    Args:
+        amax_history: A 2D tensor containing a stack of amax histories.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
         return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
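The stacked variant does the same reduction row-wise. A small sketch (again, purely illustrative) with a 2D history, one row per tensor being scaled:

```python
import torch

# Hypothetical illustration of the stacked reduction; not part of the diff.
amax_histories = torch.tensor([[0.5, 3.2, 1.7],
                               [9.0, 4.0, 6.5]])      # one row of amax history per tensor
amax_stack = torch.max(amax_histories, dim=1).values  # tensor([3.2, 9.0])
scales = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax_stack, min=1e-12)
print(scales)  # one scale per row
```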
@@ -90,21 +121,41 @@ def tensor_to_scale(
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
-def to_fp8_saturated(x, float8_dtype: torch.dtype):
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a saturated fp8 tensor.
+
+    Note:
+        The default behavior in PyTorch for casting to `float8_e4m3fn`
+        and `e5m2` is to not saturate. In this context, we should saturate.
+        A common case where we want to saturate is when the history of a
+        tensor has a maximum value of `amax1`, and the current amax value
+        is `amax2`, where `amax1 < amax2`. This is common when using delayed
+        scaling.
+    """
+
     if float8_dtype == torch.float8_e4m3fn:
         x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    else:
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2:
         x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     return x.to(float8_dtype)
 
 
-def compute_error(x, y):
+def compute_error(x: torch.Tensor, y: torch.Tensor):
+    """Computes the error between two tensors in dB.
+
+    For more details see:
+        https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+        x: The original tensor.
+        y: The tensor to compare to the original tensor.
+    """
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
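To see how the saturated cast and the dB error metric fit together, here is an end-to-end sketch (illustrative only, assuming a PyTorch build with float8 support): scale a tensor, clamp-and-cast it to `float8_e4m3fn`, dequantize, and measure the signal-to-noise ratio the same way `compute_error` does.

```python
import torch

# Hypothetical fp8 round trip; mirrors the clamp-before-cast and SNR formula above.
x = torch.randn(1024) * 10
amax = torch.max(torch.abs(x))
max_pos = torch.finfo(torch.float8_e4m3fn).max
scale = max_pos / torch.clamp(amax, min=1e-12)

x_scaled = x * scale
x_fp8 = x_scaled.clamp(min=-max_pos, max=max_pos).to(torch.float8_e4m3fn)  # saturated cast
x_roundtrip = x_fp8.to(torch.float32) / scale

snr_db = 20 * torch.log10(torch.norm(x) / torch.norm(x - x_roundtrip))
print(snr_db.item())  # higher means the fp8 round trip preserved more of the signal
```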

0 commit comments
