This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Make code more dry and standardize on to_fp8_no_autograd #228

Closed
wants to merge 2 commits
float8_experimental/float8_linear.py: 8 changes (3 additions, 5 deletions)
@@ -20,14 +20,13 @@
 
 import torch
 
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
     E4M3_MAX_POS,
     E5M2_MAX_POS,
     tensor_to_amax,
-    to_fp8_saturated,
 )


@@ -99,10 +98,9 @@ def backward(ctx, go):
         )
 
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
-        go_scaled = go * fp8_scale_dL_dY
-        bits_fp8 = to_fp8_saturated(go_scaled, torch.float8_e5m2)
+
+        res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
         empty_grads = None, None, None, None, None, None
-        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype, emulate=ctx.emulate)
         return res, *empty_grads


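Note: `to_fp8_no_autograd` is defined in float8_experimental/float8_tensor.py and its body is not part of this diff. Judging from the three lines it replaces in `backward()`, it folds the scaling, the saturated cast, and the `Float8Tensor` construction into one helper. A minimal sketch under that assumption, not the library's actual implementation:

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import to_fp8_saturated


def to_fp8_no_autograd(x, x_scale, float8_dtype, emulate):
    # Sketch only: mirrors the removed lines above.
    x_scaled = x * x_scale
    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
```

With the helper in place, the backward pass above reduces to the single `res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)` call shown in the diff.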
float8_experimental/float8_tensor.py: 10 changes (8 additions, 2 deletions)
@@ -266,9 +266,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
 
         # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs
         # And don't support mixed tensor subclasses. This will trigger the handler for
-        # the next type in the dispatch list. torch._C._TensorMeta is for FakeTensor
+        # the next type in the dispatch list
         def allowed_subclasses(type):
-            return issubclass(cls, type) or isinstance(type, torch._C._TensorMeta)
+            return (
+                issubclass(cls, type)
+                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+                or issubclass(
+                    torch._subclasses.functional_tensor.FunctionalTensor, type
+                )
+            )
 
         if not all(allowed_subclasses(t) for t in types):
             return NotImplemented
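Note: the new `allowed_subclasses` check whitelists the subclass itself plus the fake and functional tensor wrappers that tracing passes through `__torch_dispatch__`; any other subclass in `types` makes the handler return `NotImplemented`, so dispatch falls through to the next type. A small standalone sketch of the same predicate (the `Float8Like` and `Unrelated` classes are hypothetical; only `FakeTensor` and `FunctionalTensor` come from the diff):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor


class Float8Like(torch.Tensor):
    """Hypothetical stand-in for Float8Tensor, used only to exercise the check."""


class Unrelated(torch.Tensor):
    """Some other tensor subclass that should trigger a NotImplemented deferral."""


def allowed_subclasses(cls, t):
    # Same predicate as the new code in __torch_dispatch__ above.
    return (
        issubclass(cls, t)
        or issubclass(FakeTensor, t)
        or issubclass(FunctionalTensor, t)
    )


print(allowed_subclasses(Float8Like, Float8Like))  # True: our own subclass
print(allowed_subclasses(Float8Like, FakeTensor))  # True: fake tensors allowed
print(allowed_subclasses(Float8Like, Unrelated))   # False: defer to next handler
```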
float8_experimental/float8_utils.py: 2 changes (1 addition, 1 deletion)
@@ -86,7 +86,7 @@ def tensor_to_scale(x, float8_dtype):
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
-def to_fp8_saturated(x, float8_dtype):
+def to_fp8_saturated(x, float8_dtype: torch.dtype):
     # The default behavior in PyTorch for casting to `float8_e4m3fn`
     # and `e5m2` is to not saturate. In this context, we should saturate.
     # A common case where we want to saturate is when the history of a
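Note: only the signature and the leading comment of `to_fp8_saturated` appear in this hunk. A saturated cast along the lines the comment describes clamps to the largest magnitude the target fp8 format can represent before converting, so out-of-range values map to +/- max instead of inf/NaN. A minimal sketch under that assumption (448 and 57344 are the format maxima of `float8_e4m3fn` and `float8_e5m2`; the function body here is illustrative, not the library's):

```python
import torch

# Largest representable magnitudes of the two fp8 formats.
E4M3_MAX_POS = 448.0
E5M2_MAX_POS = 57344.0


def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
    # Clamp into range first so the cast saturates instead of producing
    # inf/NaN for out-of-range values.
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif float8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"Unsupported float8 dtype: {float8_dtype}")
    return x.to(float8_dtype)
```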
test/test_everything.sh: 1 change (1 addition, 0 deletions)
@@ -9,5 +9,6 @@ pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_tp.sh
+./test/test_dtensor.sh
 
 echo "all tests successful"