
Commit ca84eb4

standardize on using to_fp8_no_autograd

1 parent b67e5cf

2 files changed: 4 additions & 6 deletions


float8_experimental/float8_linear.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -20,14 +20,13 @@
 
 import torch
 
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
     E4M3_MAX_POS,
     E5M2_MAX_POS,
     tensor_to_amax,
-    to_fp8_saturated,
 )
 
 
@@ -99,10 +98,9 @@ def backward(ctx, go):
         )
 
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
-        go_scaled = go * fp8_scale_dL_dY
-        bits_fp8 = to_fp8_saturated(go_scaled, torch.float8_e5m2)
+
+        res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
         empty_grads = None, None, None, None, None, None
-        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype, emulate=ctx.emulate)
         return res, *empty_grads
 
 
```

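For context, here is a minimal sketch of what `to_fp8_no_autograd` plausibly does, reconstructed from the three inline lines it replaces in `backward()`. The real helper lives in `float8_experimental/float8_tensor.py` and may differ in signature or details; this is not the authoritative implementation.

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import to_fp8_saturated


def to_fp8_no_autograd(
    x: torch.Tensor,
    x_scale: torch.Tensor,
    float8_dtype: torch.dtype,
    emulate: bool,
) -> Float8Tensor:
    # Hypothetical reconstruction from the removed inline code: scale the
    # tensor, saturate-cast it to the target float8 dtype, and wrap the raw
    # bits in a Float8Tensor that remembers the scale and original dtype.
    # The name suggests this path is meant to run where autograd is not
    # tracking, e.g. inside backward().
    x_scaled = x * x_scale
    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
```

Folding the scale/cast/wrap sequence into one helper is what lets this commit delete the three-line version in `backward()` and, per the commit message, standardize every cast site on the same code path.
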
float8_experimental/float8_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -86,7 +86,7 @@ def tensor_to_scale(x, float8_dtype):
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
-def to_fp8_saturated(x, float8_dtype):
+def to_fp8_saturated(x, float8_dtype: torch.dtype):
     # The default behavior in PyTorch for casting to `float8_e4m3fn`
     # and `e5m2` is to not saturate. In this context, we should saturate.
     # A common case where we want to saturate is when the history of a
```
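The diff cuts off the body of `to_fp8_saturated`. Below is a plausible sketch of a saturated cast, assuming the `E4M3_MAX_POS` and `E5M2_MAX_POS` constants that `float8_linear.py` imports from this module (448.0 and 57344.0 are the standard maxima for these formats); the actual body may differ.

```python
import torch

# Largest normal magnitudes representable in each float8 format.
E4M3_MAX_POS = 448.0
E5M2_MAX_POS = 57344.0


def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Sketch of a saturated cast: clamp into the representable range before
    # casting, because PyTorch's default float8 cast does not saturate and
    # turns out-of-range values into inf/nan instead of the dtype's max.
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif float8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"Unsupported float8 dtype: {float8_dtype}")
    return x.to(float8_dtype)
```

With the scale applied first (as in `to_fp8_no_autograd` above), values that would otherwise overflow the format get pinned to the dtype's maximum instead of becoming inf/nan.
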
