This repository was archived by the owner on Aug 7, 2024. It is now read-only.

make the backward of differentiable float8 casts pass gradient as is #255

Closed
wants to merge 1 commit
README.md (2 changes: 1 addition & 1 deletion)
@@ -118,7 +118,7 @@ pytest test/test_compile.py
./test/test_tp.sh

# run all of these tests
-./test/run_everything.sh
+./test/test_everything.sh
```

# Benchmarking
float8_experimental/float8_tensor.py (46 changes: 8 additions & 38 deletions)
@@ -112,43 +112,12 @@ def to_fp8_no_autograd(
    return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config)


-def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
-    """Convert a tensor from float8 without autograd
-
-    This function will handle 3 cases:
-    1. If the tensor is a DTensor, it will convert the inner tensor to the original precision
-    2. If the tensor is a Float8Tensor, it will convert the tensor to the original precision
-    3. If the tensor is a regular tensor, it will pass through this tensor
-
-    Args:
-        x: the tensor to convert
-    """
-
-    def to_original_precision(grad):
-        if isinstance(grad, Float8Tensor):
-            return grad.to_original_precision()
-        else:
-            return grad
-
-    if isinstance(x, DTensor):
Contributor:

I recall that the main reason we had that special handling is that torch.compile (specifically the FakeTensor rules) can't derive the rule for nested subclasses automatically. If we remove this, where would we put such logic?

cc @drisspg @bdhirsh

Contributor Author:

If I understand the old state correctly:

1. grad is never a Float8Tensor in practice, because we always keep the matmul output in high precision
2. things already work without nested subclasses, which is why removing this function is fine

But it would be great to clarify ^

Contributor:

Oh, I checked the current workflow; we call:

1. cast_to_float8_e4m3fn
2. cast_to_float8_e5m2_bw

and it seems neither of these calls would call into from_fp8_no_autograd, so this should be fine.

Contributor:

Actually, it looks like I'm wrong: cast_to_float8_e4m3fn's backward would call this, but maybe it's a no-op even in the current state, as the output of the torch.scaled_mm in the backward is fp32?

-        local_grad = x.to_local()
-        original_precision_grad = to_original_precision(local_grad)
-        return DTensor.from_local(
-            original_precision_grad,
-            x.device_mesh,
-            x.placements,
-            run_check=False,
-            shape=x.size(),
-            stride=x.stride(),
-        )
-    else:
-        return to_original_precision(x)


@torch._dynamo.allow_in_graph
class ToFloat8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion to fp8
A differentiable conversion to fp8.
* forward: convert from high precision to float8
* backward: pass the gradient without changes
"""

@staticmethod
@@ -175,14 +144,15 @@ def forward(

    @staticmethod
    def backward(ctx, g):
-        grad = from_fp8_no_autograd(g)
-        return grad, None, None, None, None
+        return g, None, None, None, None


@torch._dynamo.allow_in_graph
class FromFloat8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion from fp8
A differentiable conversion from fp8.
* forward: convert from float8 to high precision
* backward: pass the gradient without changes
"""

@staticmethod
@@ -191,7 +161,7 @@ def forward(ctx, tensor):

    @staticmethod
    def backward(ctx, g):
-        return Float8Tensor.to_float8(g), None, None
+        return g, None, None


class Float8Tensor(torch.Tensor):
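Taken together, these changes make both cast wrappers straight-through on the gradient path: the forward still performs the conversion, and the backward returns the incoming gradient untouched. Below is a minimal editorial sketch of that pattern alongside the split workflow the reviewers mention, where one wrapper casts values to e4m3 in the forward and a separate wrapper casts only the gradient to e5m2 in the backward. The class names here are illustrative, not the repository's, and the bodies skip the scaling, Float8Tensor wrapping, and mm_config plumbing of the real code.

```python
# Editorial sketch, not repository code: plain fp8 dtype round trips stand in
# for the scaled Float8Tensor wrapper used by float8_experimental.
import torch


class CastToFloat8PassThroughBw(torch.autograd.Function):
    """Forward: cast to float8 e4m3. Backward: return the gradient as is."""

    @staticmethod
    def forward(ctx, t):
        # Simplified stand-in: the real code scales the input and returns a Float8Tensor.
        return t.to(torch.float8_e4m3fn).to(t.dtype)

    @staticmethod
    def backward(ctx, g):
        # Like the new ToFloat8ConstrFunc.backward in this PR: no conversion at all.
        return g


class NoopFwCastToE5M2Bw(torch.autograd.Function):
    """Forward: no-op. Backward: cast the incoming gradient to float8 e5m2."""

    @staticmethod
    def forward(ctx, t):
        return t

    @staticmethod
    def backward(ctx, g):
        # Simplified stand-in for a cast_to_float8_e5m2_bw-style wrapper:
        # only the gradient ever goes through the e5m2 cast.
        return g.to(torch.float8_e5m2).to(g.dtype)


x = torch.randn(8, requires_grad=True)
y = NoopFwCastToE5M2Bw.apply(CastToFloat8PassThroughBw.apply(x))
y.sum().backward()
print(x.grad)  # all ones: 1.0 is exactly representable in e5m2
```

Because the second wrapper is a no-op in the forward, only gradients pass through the e5m2 cast, which lines up with the reviewers' observation that the removed from_fp8_no_autograd conversion was effectively a no-op for high-precision gradients.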
test/test_base.py (12 changes: 12 additions & 0 deletions)
@@ -56,6 +56,18 @@ def test_preserves_dtype(self) -> None:
        x3_hp = x2_lp.to_original_precision()
        self.assertTrue(x3_hp.dtype == hp_dtype)

+    def test_differentiable_casts(self) -> None:
+        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        for f8_dtype in lp_dtypes:
+            x = torch.randn(1).requires_grad_()
+            grad = torch.randn(1)
+            x_s = tensor_to_scale(x, f8_dtype)
+            x_f8 = Float8Tensor.to_float8(x, x_s, f8_dtype)
+            x_f8_hp = x_f8.to_original_precision()
+            x_f8_hp.backward(grad)
+            # the gradient should be unchanged through both casts
+            torch.testing.assert_close(grad, x.grad, rtol=0, atol=0)


class TestFloat8Linear:
    def _test_linear_impl(
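The behavior the new test asserts can also be exercised interactively. This is a minimal sketch, assuming Float8Tensor is importable from float8_experimental.float8_tensor (the file changed above) and tensor_to_scale from float8_experimental.float8_utils; only calls already shown in the test are used.

```python
# Standalone round trip through both differentiable casts; with this PR the
# gradient reaches the leaf unchanged. Import paths are assumptions.
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, requires_grad=True)
grad = torch.randn(16)

scale = tensor_to_scale(x, torch.float8_e4m3fn)
x_f8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)  # high precision -> float8
x_hp = x_f8.to_original_precision()                           # float8 -> high precision

x_hp.backward(grad)
torch.testing.assert_close(x.grad, grad, rtol=0, atol=0)  # bit-identical gradient
```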