This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 6f22688

drisspg authored and facebook-github-bot committed
Make code more DRY and standardize on to_fp8_no_autograd (#228)

Summary: Remove direct Float8Tensor constructor invocations inside torch.autograd functions and use the `to_fp8_no_autograd` function instead. This also fixes a problem where the DTensor tests weren't in test_everything.sh, which let the `__torch_dispatch__` handling be incorrect; that is fixed here as well. `./test/test_everything.sh` is all green.

Pull Request resolved: #228

Reviewed By: wanchaol

Differential Revision: D54235141

Pulled By: drisspg

fbshipit-source-id: ac71256e84292d03ea19bb2ea3428d99625963b9
1 parent b67e5cf commit 6f22688
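
The pattern being factored out is easiest to see in the float8_linear.py diff below: scale the incoming gradient, saturate-cast it to the target fp8 dtype, and wrap the result in a Float8Tensor. As orientation, here is a minimal sketch of what a helper with that job looks like, inferred only from the lines the diff removes; the actual `to_fp8_no_autograd` lives in float8_experimental/float8_tensor.py and may differ in details and error handling.

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import to_fp8_saturated


def to_fp8_no_autograd_sketch(x, x_scale, float8_dtype, emulate):
    # Sketch only: the three steps the backward pass used to spell out inline.
    x_scaled = x * x_scale                               # scale into the fp8 range
    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)  # saturating cast
    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)

In the backward pass the real helper is then called as `to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)`, exactly as the first diff below shows.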

4 files changed: 13 additions & 8 deletions

float8_experimental/float8_linear.py

Lines changed: 3 additions & 5 deletions
@@ -20,14 +20,13 @@

 import torch

-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd

 from float8_experimental.float8_utils import (
     amax_history_to_scale,
     E4M3_MAX_POS,
     E5M2_MAX_POS,
     tensor_to_amax,
-    to_fp8_saturated,
 )

@@ -99,10 +98,9 @@ def backward(ctx, go):
         )

         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
-        go_scaled = go * fp8_scale_dL_dY
-        bits_fp8 = to_fp8_saturated(go_scaled, torch.float8_e5m2)
+
+        res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
         empty_grads = None, None, None, None, None, None
-        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype, emulate=ctx.emulate)
         return res, *empty_grads

float8_experimental/float8_tensor.py

Lines changed: 8 additions & 2 deletions
@@ -266,9 +266,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):

         # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs
         # And don't support mixed tensor subclasses. This will trigger the handler for
-        # the next type in the dispatch list. torch._C._TensorMeta is for FakeTensor
+        # the next type in the dispatch list
         def allowed_subclasses(type):
-            return issubclass(cls, type) or isinstance(type, torch._C._TensorMeta)
+            return (
+                issubclass(cls, type)
+                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+                or issubclass(
+                    torch._subclasses.functional_tensor.FunctionalTensor, type
+                )
+            )

         if not all(allowed_subclasses(t) for t in types):
             return NotImplemented
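
To make the effect of the new check concrete: when a tensor subclass other than Float8Tensor (and other than FakeTensor or FunctionalTensor) shows up in `types`, the predicate is false for it, `__torch_dispatch__` returns NotImplemented, and dispatch falls through to that subclass's own handler; that hand-off is what the DTensor fix in this commit relies on. Below is a small self-contained illustration with hypothetical class names (the FunctionalTensor clause omitted for brevity).

import torch
from torch._subclasses.fake_tensor import FakeTensor


class HypotheticalFp8Tensor(torch.Tensor):
    # Stand-in for Float8Tensor, used only to exercise the predicate.
    pass


class SomeOtherWrapperSubclass(torch.Tensor):
    # Stand-in for an unrelated wrapper subclass such as DTensor.
    pass


def allowed_subclasses(t):
    # Same shape as the updated check above, minus the FunctionalTensor clause.
    return issubclass(HypotheticalFp8Tensor, t) or issubclass(FakeTensor, t)


print(allowed_subclasses(HypotheticalFp8Tensor))     # True
print(allowed_subclasses(FakeTensor))                # True
print(allowed_subclasses(SomeOtherWrapperSubclass))  # False -> return NotImplemented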

float8_experimental/float8_utils.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def tensor_to_scale(x, float8_dtype):
     return amax_to_scale(amax, float8_dtype, x.dtype)


-def to_fp8_saturated(x, float8_dtype):
+def to_fp8_saturated(x, float8_dtype: torch.dtype):
     # The default behavior in PyTorch for casting to `float8_e4m3fn`
     # and `e5m2` is to not saturate. In this context, we should saturate.
     # A common case where we want to saturate is when the history of a
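
The comment above describes a saturating cast: values outside the fp8 range should clamp to the largest finite value instead of overflowing, which PyTorch's plain cast to these dtypes does not do. A minimal sketch of that behavior, assuming the standard fp8 maxima (these match the E4M3_MAX_POS / E5M2_MAX_POS constants imported in float8_linear.py; the repo's actual implementation may differ):

import torch

E4M3_MAX_POS = 448.0    # largest finite float8_e4m3fn value
E5M2_MAX_POS = 57344.0  # largest finite float8_e5m2 value


def to_fp8_saturated_sketch(x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Clamp to the largest representable magnitude before casting so that
    # out-of-range values map to +/-max rather than overflowing.
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
    else:
        x = x.clamp(min=-E5M2_MAX_POS, max=E5M2_MAX_POS)
    return x.to(float8_dtype)

For example, to_fp8_saturated_sketch(torch.tensor([1e6]), torch.float8_e5m2) yields the finite maximum 57344 instead of an infinity.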

test/test_everything.sh

Lines changed: 1 addition & 0 deletions
@@ -9,5 +9,6 @@ pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_tp.sh
+./test/test_dtensor.sh

 echo "all tests successful"
