This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4a27a27

still some numeric issues on amd:
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape0-False] - AssertionError: -3.183703660964966 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape1-False] - AssertionError: -3.2964067459106445 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape2-False] - AssertionError: -3.091813564300537 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape0-True] - AssertionError: 7.574269771575928 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape0-False] - AssertionError: -2.132262706756592 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape1-True] - AssertionError: 8.139453887939453 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape1-False] - AssertionError: -1.483538269996643 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape2-True] - AssertionError: 8.950117111206055 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape2-False] - AssertionError: -1.840381145477295 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape0-False] - AssertionError: -3.1304943561553955 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape1-False] - AssertionError: -3.246392250061035 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape2-False] - AssertionError: -3.015180826187134 is too low
1 parent 829e9c4 commit 4a27a27

File tree: 6 files changed (+83, -42 lines)
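
The "is too low" assertions above come from comparing the fp8 modules against a higher-precision reference with a signal-to-quantization-noise style metric and failing when the ratio drops below a floor. The test file itself is not part of this diff, so the sketch below is only illustrative: the function name, the stand-in tensors, and the 25 dB threshold are placeholders, not the repository's actual values.

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in decibels; higher means closer to the reference
    return 20 * torch.log10(ref.norm() / (ref - approx).norm())

ref = torch.randn(16, 32)
approx = ref + 0.01 * torch.randn_like(ref)  # stand-in for an fp8 round trip
threshold_db = 25.0  # hypothetical floor, not the suite's real tolerance
value = sqnr_db(ref, approx)
assert value > threshold_db, f"{value.item()} is too low"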

float8_experimental/float8_dynamic_linear.py

Lines changed: 22 additions & 14 deletions

@@ -7,9 +7,10 @@
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """
 import torch
+from typing import Optional
 
 from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
-from float8_experimental.float8_utils import IS_AMD, tensor_to_scale
+from float8_experimental.float8_utils import IS_AMD, tensor_to_scale, FP8Dtypes
 
 
 @torch._dynamo.allow_in_graph
@@ -24,16 +25,17 @@ def forward(
         ctx,
         tensor,
         emulate: bool,
+        fp8_dtype_bw: torch.dtype
     ):
         ctx.emulate = emulate
+        ctx.fp8_dtype_bw = fp8_dtype_bw
         return tensor
 
     @staticmethod
     def backward(ctx, gradY):
-        fp8_dtype = torch.float8_e5m2fnuz if IS_AMD else torch.float8_e5m2
-        gradY_scale = tensor_to_scale(gradY, fp8_dtype)
-        fp8_tensor = to_fp8_no_autograd(gradY, gradY_scale, fp8_dtype, ctx.emulate)
-        return fp8_tensor, None
+        gradY_scale = tensor_to_scale(gradY, ctx.fp8_dtype_bw)
+        fp8_tensor = to_fp8_no_autograd(gradY, gradY_scale, ctx.fp8_dtype_bw, ctx.emulate)
+        return fp8_tensor, None, None
 
 
 def cast_x_to_float8_e4m3fn_pre_hook(module, args):
@@ -61,18 +63,23 @@ class Float8DynamicLinear(torch.nn.Linear):
     conversion to fp8 of the input and weight tensors.
     """
 
-    def __init__(self, use_activation_hooks: bool, **super_kwargs):
+    def __init__(self, use_activation_hooks: bool, fp8_dtype: FP8Dtypes, **super_kwargs):
         """
         Args:
             use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
+            fp8_dtype (FP8Dtypes): the fp8 dtypes to use for the forward and backward casts
         """
         super().__init__(**super_kwargs)
 
         self.use_activation_hooks = use_activation_hooks
+        # I want to store the dataclass but I think that will break torch compile
+        self.fp8_dtype_fw = fp8_dtype.fp8_dtype_fw
+        self.fp8_dtype_bw = fp8_dtype.fp8_dtype_bw
+        self.emulate = False
 
-    def forward(self, x):
+    def forward(self, input):
         # cast x to float8_e4m3fn if not using activation hooks
-        x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3(x)
+        x_fp8 = input if self.use_activation_hooks else self.cast_to_float8_e4m3(input)
 
         # cast w to float8_e4m3fn
         w_fp8 = self.cast_to_float8_e4m3(self.weight)
@@ -94,10 +101,9 @@ def cast_to_float8_e4m3(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
         - On AMD Gpus, it casts to torch.float8_e4m3fnuz
 
         """
-        fp8_dtype = torch.float8_e4m3fnuz if IS_AMD else torch.float8_e4m3fn
-        scale = tensor_to_scale(inpt_tensor, fp8_dtype)
+        scale = tensor_to_scale(inpt_tensor, self.fp8_dtype_fw)
         return Float8Tensor.to_float8(
-            inpt_tensor, scale, fp8_dtype, emulate=self.emulate
+            inpt_tensor, scale, self.fp8_dtype_fw, emulate=self.emulate
         )
 
     def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
@@ -110,11 +116,11 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
         - On AMD Gpus, it casts to torch.float8_e4m3fnuz
 
         """
-        return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
+        return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate, self.fp8_dtype_bw)
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, use_activation_hooks: bool = False
+        cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -124,13 +130,15 @@ def from_float(
             emulate (bool): whether to emulate fp8 matmul logic in float32
             use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
         """
+        if fp8_dtypes is None:
+            fp8_dtypes = FP8Dtypes()
         with torch.device("meta"):
             super_kwargs = {
                 "in_features": mod.in_features,
                 "out_features": mod.out_features,
                 "bias": False,
            }
-        new_mod = cls(use_activation_hooks, **super_kwargs)
+        new_mod = cls(use_activation_hooks, fp8_dtypes, **super_kwargs)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
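
With this change the AMD branch no longer lives inside the autograd function; the backward dtype rides along through NoopFwToFloat8E5M2Bw.apply. Below is a minimal sketch of how the new plumbing could be exercised, assuming a PyTorch build that exposes the fnuz dtypes and a CUDA/ROCm device; emulate=True keeps the matmul in float32, so no hardware fp8 matmul support is needed.

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_utils import FP8Dtypes

# AMD-style configuration: fnuz variants for both the forward and backward casts
amd_dtypes = FP8Dtypes(
    fp8_dtype_fw=torch.float8_e4m3fnuz,  # activations and weights
    fp8_dtype_bw=torch.float8_e5m2fnuz,  # gradients
)

lin = torch.nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16)
fp8_lin = Float8DynamicLinear.from_float(lin, emulate=True, fp8_dtypes=amd_dtypes)

x = torch.randn(8, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
fp8_lin(x).sum().backward()  # gradY is cast with ctx.fp8_dtype_bw in backward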

float8_experimental/float8_linear.py

Lines changed: 35 additions & 19 deletions

@@ -14,7 +14,7 @@
 
 import dataclasses
 
-from typing import Optional
+from typing import Optional, Literal
 
 import float8_experimental.config as config
 
@@ -27,17 +27,18 @@
     E4M3_MAX_POS,
     E5M2_MAX_POS,
     tensor_to_amax,
+    FP8Dtypes
 )
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
-    x,
-    cur_amax,
-    amax_history,
-    scale,
-    scale_fn_name,
-    float8_dtype,
-    is_initialized,
+    x: torch.Tensor,
+    cur_amax: torch.Tensor,
+    amax_history: torch.Tensor,
+    scale: torch.Tensor,
+    scale_fn_name: Literal["max"],
+    float8_dtype: torch.dtype,
+    is_initialized: bool,
 ):
     """
     If x is about to be cast to `float8` and the amax buffers are not initialized,
@@ -74,11 +75,13 @@ def forward(
         scale_fn_name,
         is_amax_initialized,
         emulate: bool,
+        fp8_dtype: torch.dtype,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
         ctx.emulate = emulate
+        ctx.fp8_dtype = fp8_dtype
         return tensor
 
     @staticmethod
@@ -93,14 +96,14 @@ def backward(ctx, go):
             fp8_amax_history_dL_dY,
             fp8_scale_dL_dY,
             scale_fn_name,
-            torch.float8_e5m2,
+            ctx.fp8_dtype,
             is_amax_initialized,
         )
 
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
-        res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
-        empty_grads = None, None, None, None, None, None
+        res = to_fp8_no_autograd(go, fp8_scale_dL_dY, ctx.fp8_dtype, ctx.emulate)
+        empty_grads = None, None, None, None, None, None, None
         return res, *empty_grads
 
 
@@ -178,6 +181,14 @@ def __init__(self, *args, **kwargs):
         # and torch.compile, this option can disable them
         self.enable_pre_and_post_forward = config.enable_pre_and_post_forward
 
+        # In the forward we will cast both the activation and weight to float8.
+        # There are currently 4 different fp8 variants in pytorch, see
+        # https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md
+        # fp8_dtype_fw is the type used for casting the activation and weight
+        # fp8_dtype_bw is the type used for casting the gradient
+        self.fp8_dtype_fw = torch.float8_e4m3fn
+        self.fp8_dtype_bw = torch.float8_e5m2
+
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:
@@ -212,11 +223,11 @@ def cast_x_to_float8(
             self.fp8_amax_history_x,
             self.fp8_scale_x,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            self.fp8_dtype_fw,
             is_amax_initialized,
         )
         x_fp8 = Float8Tensor.to_float8(
-            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
+            x, self.fp8_scale_x, self.fp8_dtype_fw, self.fp8_amax_x, self.emulate
         )
         return x_fp8
 
@@ -230,14 +241,14 @@ def cast_w_to_float8(
             self.fp8_amax_history_w,
             self.fp8_scale_w,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            self.fp8_dtype_fw,
             is_amax_initialized,
         )
 
         w_fp8 = Float8Tensor.to_float8(
             w,
             self.fp8_scale_w,
-            torch.float8_e4m3fn,
+            self.fp8_dtype_fw,
             self.fp8_amax_w,
             self.emulate,
         )
@@ -255,6 +266,7 @@ def cast_y_to_float8_in_bw(
             scale_fn_name,
             self.is_amax_initialized,
             emulate,
+            self.fp8_dtype_bw,
         )
         return y
 
@@ -286,10 +298,10 @@ class Float8Linear(Float8LinearMixin, torch.nn.Linear):
     scales in way friendly to delayed scaling.
     """
 
-    def forward(self, x):
-        self.float8_pre_forward(x)
+    def forward(self, input):
+        self.float8_pre_forward(input)
 
-        x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
+        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
         y = torch.matmul(x_fp8, w_fp8.t())
@@ -304,7 +316,7 @@ def forward(self, x):
         return y
 
     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
+    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
@@ -314,13 +326,17 @@ def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = Fal
         use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
         """
         assert not use_activation_hooks, "use_activation_hooks is not supported yet!"
+        if fp8_dtypes is None:
+            fp8_dtypes = FP8Dtypes()
         # TODO Follow up! This is a great idea but we need the mixin base to create real
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        new_mod.fp8_dtype_fw = fp8_dtypes.fp8_dtype_fw
+        new_mod.fp8_dtype_bw = fp8_dtypes.fp8_dtype_bw
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         return new_mod
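
For the delayed-scaling path, __init__ now seeds the non-AMD defaults and from_float overrides them when an FP8Dtypes is supplied. A small sketch of that behavior, assuming the fnuz dtypes are available in the installed PyTorch build:

import torch
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_utils import FP8Dtypes

# Defaults from __init__: e4m3fn for activations/weights, e5m2 for gradients
m = Float8Linear.from_float(torch.nn.Linear(16, 16, bias=False))
assert m.fp8_dtype_fw is torch.float8_e4m3fn
assert m.fp8_dtype_bw is torch.float8_e5m2

# Passing an FP8Dtypes overrides both directions (here with the AMD fnuz pair)
m_amd = Float8Linear.from_float(
    torch.nn.Linear(16, 16, bias=False),
    fp8_dtypes=FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
assert m_amd.fp8_dtype_fw is torch.float8_e4m3fnuz
assert m_amd.fp8_dtype_bw is torch.float8_e5m2fnuz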

float8_experimental/float8_linear_utils.py

Lines changed: 4 additions & 1 deletion

@@ -14,7 +14,7 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 
-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import amax_history_to_scale_stack, FP8Dtypes
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
 
 log = logging.getLogger(__name__)
@@ -34,13 +34,15 @@ def get_float8_linear(
     linear_ref: torch.nn.Linear,
     emulate: bool = False,
     use_activation_hooks: bool = False,
+    fp8_dtypes: Optional[FP8Dtypes] = None,
 ):
     """Returns a Float8Linear module of the given type, initialized from linear_ref.
     Args:
         linear_type: The type of Float8Linear to return.
         linear_ref: The linear module to initialize from.
         emulate: Whether to emulate the fp8 matmul logic in float32.
         use_activation_hooks: Whether to use activation hooks for dynamic linear.
+        fp8_dtypes: The FP8 dtypes to use.
     """
     LINEAR_TYPE_MAP = {
         LinearType.DELAYED: Float8Linear,
@@ -54,6 +56,7 @@ def get_float8_linear(
         copy.deepcopy(linear_ref),
         emulate=emulate,
         use_activation_hooks=use_activation_hooks,
+        fp8_dtypes=fp8_dtypes,
     )
 
 
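
This helper, used by the test suite, simply forwards the new argument to whichever from_float it dispatches to, so both LinearType.DELAYED and LinearType.DYNAMIC pick up the same dtype pair. A hedged usage sketch; LinearType is assumed to be importable from this module, as the LINEAR_TYPE_MAP above suggests:

import torch
from float8_experimental.float8_linear_utils import LinearType, get_float8_linear
from float8_experimental.float8_utils import FP8Dtypes

ref = torch.nn.Linear(32, 32, bias=False)
fnuz = FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)

dynamic_fp8 = get_float8_linear(LinearType.DYNAMIC, ref, emulate=True, fp8_dtypes=fnuz)
delayed_fp8 = get_float8_linear(LinearType.DELAYED, ref, emulate=True, fp8_dtypes=fnuz)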

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 1 deletion

@@ -230,7 +230,7 @@ def to_float8(
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         emulate: bool = False,
-    ):
+    ) -> "Float8Tensor":
         """Converts a higher precision tensor to float8 in a differentiable way.
 
         Args:
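
The only change here is the return annotation, but it documents the call pattern used throughout the linears above. A short, hedged example of the tensor-level API using the same tensor_to_scale then to_float8 sequence shown in the diffs; emulate=True is used so it should not require fp8 matmul hardware, and a GPU may still be needed if these utilities reject CPU tensors.

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

t = torch.randn(4, 4)
scale = tensor_to_scale(t, torch.float8_e4m3fn)
t_fp8 = Float8Tensor.to_float8(t, scale, torch.float8_e4m3fn, emulate=True)
assert isinstance(t_fp8, Float8Tensor)  # matches the new "Float8Tensor" annotation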

float8_experimental/float8_utils.py

Lines changed: 8 additions & 1 deletion

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Literal
+from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
@@ -27,6 +28,12 @@
 IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
 
 
+@dataclass(frozen=True)
+class FP8Dtypes:
+    """Defines the fp8 dtypes to be used in forward and backward computations"""
+    fp8_dtype_fw: torch.dtype = torch.float8_e4m3fn
+    fp8_dtype_bw: torch.dtype = torch.float8_e5m2
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -61,7 +68,7 @@ def amax_to_scale(
 @torch.no_grad()
 def amax_history_to_scale(
     amax_history: torch.Tensor,
-    float8_dtype: torch.Tensor,
+    float8_dtype: torch.dtype,
     orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ):
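
FP8Dtypes is a frozen dataclass, so a configuration can be built once and shared across modules without risk of mutation; the defaults encode the non-AMD pairing (e4m3fn forward, e5m2 backward), while the fnuz variants can be passed in for AMD. A brief sketch, assuming the fnuz dtypes exist in the installed PyTorch build:

import torch
from float8_experimental.float8_utils import FP8Dtypes

default_dtypes = FP8Dtypes()  # fp8_dtype_fw=e4m3fn, fp8_dtype_bw=e5m2
amd_dtypes = FP8Dtypes(
    fp8_dtype_fw=torch.float8_e4m3fnuz,
    fp8_dtype_bw=torch.float8_e5m2fnuz,
)

assert default_dtypes.fp8_dtype_bw is torch.float8_e5m2
# frozen=True: attempting to mutate a field raises dataclasses.FrozenInstanceError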
