This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 5fc07fc

drisspg authored and facebook-github-bot committed
Adds utilities for AMD fp8 dtype support, follow up PR to add option to the configs (#235)
Summary: AMD GPUs support a different fp8 dtype than NVIDIA. These dtypes were added to PyTorch, and we update Float8Tensor construction to use the format dependent on the arch. For a detailed summary see: https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md

Pull Request resolved: #235
Reviewed By: malfet
Differential Revision: D58044802
Pulled By: drisspg
fbshipit-source-id: fed15edaceceaa79b3fbcc9644dd51aee3641dd6
1 parent cdb7867 commit 5fc07fc
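As context for the diffs below, the arch-dependent choice amounts to picking the fnuz formats on ROCm builds. Here is a minimal standalone sketch; the helper name pick_fp8_dtypes is hypothetical and not part of this commit, which instead wires the choice into Float8Tensor construction:

import torch

# True on a ROCm (AMD) build of PyTorch with a visible GPU, mirroring the
# IS_AMD check this commit adds to float8_utils.py.
IS_AMD = torch.cuda.is_available() and torch.version.hip is not None

def pick_fp8_dtypes():
    # Hypothetical helper for illustration only: return the (forward, backward)
    # fp8 dtypes appropriate for the current architecture.
    if IS_AMD:
        return torch.float8_e4m3fnuz, torch.float8_e5m2fnuz  # AMD fnuz variants
    return torch.float8_e4m3fn, torch.float8_e5m2  # NVIDIA/OCP variants

fwd_dtype, bwd_dtype = pick_fp8_dtypes()
print(fwd_dtype, bwd_dtype)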

File tree

4 files changed: +96 -53 lines changed

float8_experimental/float8_dynamic_linear.py
float8_experimental/float8_linear.py
float8_experimental/float8_utils.py
test/test_base.py

float8_experimental/float8_dynamic_linear.py
Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             "bias": False,
         }
         new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
+        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
         new_mod.backward_config = ScaledMMConfig(emulate, False)
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
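Aside (not part of the commit): the new expression is just a simplification of the old ternary; the two agree for both values of emulate, as a quick standalone check shows.

# Standalone sanity check that `not bool(emulate)` matches the old ternary.
for emulate in (True, False):
    assert (True if not emulate else False) == (not bool(emulate))
print("equivalent for both values of emulate")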

float8_experimental/float8_linear.py
Lines changed: 9 additions & 9 deletions

@@ -26,12 +26,7 @@
     to_fp8_no_autograd,
 )
 
-from float8_experimental.float8_utils import (
-    amax_history_to_scale,
-    E4M3_MAX_POS,
-    E5M2_MAX_POS,
-    tensor_to_amax,
-)
+from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -142,18 +137,23 @@ def __init__(self, *args, **kwargs):
         self.recipe = delayed_scaling_recipe
         history_len = self.recipe.history_len
 
-        self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS]))
+        # Default values for history buffers, see above TODO
+        default_x = torch.finfo(torch.float8_e4m3fn).max
+        default_w = torch.finfo(torch.float8_e4m3fn).max
+        default_dl_dy = torch.finfo(torch.float8_e5m2).max
+
+        self.register_always_float32_buffer("fp8_amax_x", torch.tensor([default_x]))
         self.register_always_float32_buffer(
             "fp8_amax_history_x", torch.zeros(history_len)
         )
         self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0]))
-        self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS]))
+        self.register_always_float32_buffer("fp8_amax_w", torch.tensor([default_w]))
         self.register_always_float32_buffer(
             "fp8_amax_history_w", torch.zeros(history_len)
         )
         self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0]))
         self.register_always_float32_buffer(
-            "fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS])
+            "fp8_amax_dL_dY", torch.tensor([default_dl_dy])
        )
         self.register_always_float32_buffer(
             "fp8_amax_history_dL_dY", torch.zeros(history_len)

float8_experimental/float8_utils.py
Lines changed: 81 additions & 36 deletions

@@ -4,49 +4,66 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Literal, Tuple
 
 import torch
 import torch.distributed as dist
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
 
-# define the e4m3/e5m2 constants
-E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
-E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
-
-FP16_MAX_POS = torch.finfo(torch.float16).max
-
 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
+IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+FP8_TYPES = {
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+    torch.float8_e4m3fnuz,
+    torch.float8_e5m2fnuz,
+}
+
 
 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    """Converts the amax value of a tensor to the fp8 scale.
+    Args:
+        amax: The amax value of the tensor.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+    """
     scale = torch.empty_like(amax, dtype=torch.float32)
-    if float8_dtype == torch.float8_e4m3fn:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
-        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+    if float8_dtype in FP8_TYPES:
+        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
     # Ensure that the scale is representable in float16,
     # this helps when amax is small. We are assuming that we don't need
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=FP16_MAX_POS)
+        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
     scale.copy_(res)
     return scale
 
 
 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.Tensor,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: Literal["max"],
 ):
+    """Takes in a history of amax values and returns a scale tensor.
+    Args:
+        amax_history: A tensor containing the history of amax values.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
         return amax_to_scale(amax, float8_dtype, orig_dtype)
@@ -58,9 +75,15 @@ def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
    orig_dtype: torch.dtype,
-    history_to_scale_fn_type: str,
+    history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
-    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    """Takes in a stack of amax_history tensors and returns a scale tensor.
+    Args:
+        amax_history: A 2D tensor containing a stack of amax histories.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
         return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
@@ -90,21 +113,35 @@ def tensor_to_scale(
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
-def to_fp8_saturated(x, float8_dtype: torch.dtype):
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
-    if float8_dtype == torch.float8_e4m3fn:
-        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a saturated fp8 tensor.
+
+    Note:
+        The default behavior in PyTorch for casting to `float8_e4m3fn`
+        and `e5m2` is to not saturate. In this context, we should saturate.
+        A common case where we want to saturate is when the history of a
+        tensor has a maximum value of `amax1`, and the current amax value
+        is `amax2`, where `amax1 < amax2`. This is common when using delayed
+        scaling.
+    """
+    if float8_dtype in FP8_TYPES:
+        max_value = torch.finfo(float8_dtype).max
+        x = x.clamp(min=-max_value, max=max_value)
+        return x.to(float8_dtype)
     else:
-        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
-    return x.to(float8_dtype)
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
+
+
+def compute_error(x: torch.Tensor, y: torch.Tensor):
+    """Computes the error between two tensors in dB.
 
+    For more details see:
+    https://en.wikipedia.org/wiki/Signal-to-noise_ratio
 
-def compute_error(x, y):
+    Args:
+        x: The original tensor.
+        y: The tensor to compare to the original tensor.
+    """
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
@@ -113,11 +150,19 @@ def compute_error(x, y):
 def fp8_tensor_statistics(
     tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
 ) -> Tuple[int, ...]:
-    """Calculate FP8 tensor stats"""
-    if float8_dtype == torch.float8_e4m3fn:
-        FP8_MAX = E4M3_MAX_POS
-    else:  # e5m2
-        FP8_MAX = E5M2_MAX_POS
+    """Calculate FP8 tensor stats
+
+    Args:
+        tensor: The tensor to calculate stats for.
+        float8_dtype: The float8 dtype.
+
+    Returns:
+        A tuple containing the number of zeros and the number of max values.
+    """
+    if float8_dtype in FP8_TYPES:
+        FP8_MAX = torch.finfo(float8_dtype).max
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype)
     num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item()
     num_zero = (tensor_orig_type == 0).sum().item()
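A short usage sketch of the updated utilities (standalone, assuming float8_experimental is installed and the PyTorch build has the fnuz dtypes). Because the dtype-specific constants are gone, the same calls now work for any dtype in FP8_TYPES, including the AMD fnuz variants:

import torch
from float8_experimental.float8_utils import amax_to_scale, to_fp8_saturated

x = torch.randn(16, 16)
amax = x.abs().max()
# Any dtype in FP8_TYPES is accepted; unsupported dtypes now raise ValueError.
scale = amax_to_scale(amax, torch.float8_e4m3fnuz, x.dtype)
x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fnuz)
print(scale.item(), x_fp8.dtype)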

test/test_base.py
Lines changed: 5 additions & 7 deletions

@@ -31,10 +31,8 @@
 from float8_experimental.float8_utils import (
     amax_to_scale,
     compute_error,
-    E4M3_MAX_POS,
-    E5M2_MAX_POS,
-    FP16_MAX_POS,
     fp8_tensor_statistics,
+    FP8_TYPES,
     tensor_to_scale,
 )
 
@@ -118,9 +116,10 @@ def _test_linear_impl(
             "fp8_amax_w",
             "fp8_amax_dL_dY",
         ]
+        max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES}
        for buffer_name in amax_buffer_names:
             buffer_value = getattr(m_fp8, buffer_name)
-            for init_val in (E4M3_MAX_POS, E5M2_MAX_POS):
+            for init_val in max_float8_pos:
                 assert torch.ne(
                     buffer_value, torch.tensor(init_val)
                 ), f"{buffer_name} not filled, current value {buffer_value}"
@@ -412,9 +411,8 @@ def test_small_amax_float16(self, float8_dtype):
         #
         # amax + eps >= fp8_max_pos / fp16_max_pos
 
-        float8_max_pos = (
-            E4M3_MAX_POS if float8_dtype is torch.float8_e4m3fn else E5M2_MAX_POS
-        )
+        float8_max_pos = torch.finfo(float8_dtype).max
+        FP16_MAX_POS = torch.finfo(torch.float16).max
 
         target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12)
         x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
