This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 605fc1d

vkuzo authored and facebook-github-bot committed
make the backward of differentiable float8 casts pass gradient as is (#255)
Summary:

Behavior before:
* high precision to float8 in fw, float8 to high precision in bw
* float8 to high precision in fw, high precision to float8 in bw if grad is a Float8Tensor, pass gradient unchanged otherwise

Behavior after:
* high precision to float8 in fw, pass gradient unchanged in bw
* float8 to high precision in fw, pass gradient unchanged in bw

Motivation for the new state:
1. we want gradients to be in high precision unless specified otherwise by the float8 recipe, and the logic to specify grad casting to float8 before the matmul is better implemented elsewhere
2. there is actually no logic change in this diff, as the backward casts were not getting hit from existing code; this diff just makes the intended behavior clearer

Pull Request resolved: #255

Test Plan:
```
./test/test_everything.sh
```

Reviewed By: drisspg, malfet, wanchaol

Differential Revision: D56956823

Pulled By: vkuzo

fbshipit-source-id: 1388420ad933a88986443effdf13ef1f8516138b
1 parent 7c53229 commit 605fc1d
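
As background for the summary above, here is a minimal, self-contained sketch of the pattern being adopted: the float8 conversion happens in forward, and backward hands the incoming gradient back as-is. This is not code from this repository; `Float8RoundTripSTE` and its signature are made up for illustration, and the forward round-trips through float8 (quantize then dequantize) so the demo stays in one dtype, whereas the repo's functions produce and consume an actual `Float8Tensor`.

```python
import torch


class Float8RoundTripSTE(torch.autograd.Function):
    """Illustrative only: scale, round-trip through float8 in forward,
    and pass the gradient through unchanged in backward."""

    @staticmethod
    def forward(ctx, tensor, scale):
        # forward: quantize to float8 and immediately dequantize, so the
        # output stays in the input's high-precision dtype
        fp8 = (tensor * scale).to(torch.float8_e4m3fn)
        return fp8.to(tensor.dtype) / scale

    @staticmethod
    def backward(ctx, g):
        # backward: the gradient is passed as-is; one None for the
        # non-differentiable `scale` argument
        return g, None


x = torch.randn(8, requires_grad=True)
out = Float8RoundTripSTE.apply(x, torch.tensor(1.0))
out.sum().backward()
# x.grad is exactly all-ones: the float8 rounding in forward did not
# touch the gradient
print(x.grad)
```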

File tree: 3 files changed, +21 −39 lines

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -118,7 +118,7 @@ pytest test/test_compile.py
 ./test/test_tp.sh
 
 # run all of these tests
-./test/run_everything.sh
+./test/test_everything.sh
 ```
 
 # Benchmarking
````

float8_experimental/float8_tensor.py

Lines changed: 8 additions & 38 deletions
```diff
@@ -112,43 +112,12 @@ def to_fp8_no_autograd(
     return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config)
 
 
-def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
-    """Convert a tensor from float8 without autograd
-
-    This function will handle 3 cases:
-    1. If the tensor is a DTensor, it will convert the inner tensor to the original precision
-    2. If the tensor is a Float8Tensor, it will convert the tensor to the original precision
-    3. If the tensor is a regular tensor, it will pass through this tensor
-
-    Args:
-        x: the tensor to convert
-    """
-
-    def to_original_precision(grad):
-        if isinstance(grad, Float8Tensor):
-            return grad.to_original_precision()
-        else:
-            return grad
-
-    if isinstance(x, DTensor):
-        local_grad = x.to_local()
-        original_precision_grad = to_original_precision(local_grad)
-        return DTensor.from_local(
-            original_precision_grad,
-            x.device_mesh,
-            x.placements,
-            run_check=False,
-            shape=x.size(),
-            stride=x.stride(),
-        )
-    else:
-        return to_original_precision(x)
-
-
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
-    A differentiable conversion to fp8
+    A differentiable conversion to fp8.
+    * forward: convert from high precision to float8
+    * backward: pass the gradient without changes
     """
 
     @staticmethod
@@ -175,14 +144,15 @@ def forward(
 
     @staticmethod
     def backward(ctx, g):
-        grad = from_fp8_no_autograd(g)
-        return grad, None, None, None, None
+        return g, None, None, None, None
 
 
 @torch._dynamo.allow_in_graph
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
-    A differentiable conversion from fp8
+    A differentiable conversion from fp8.
+    * forward: convert from float8 to high precision
+    * backward: pass the gradient without changes
     """
 
     @staticmethod
@@ -191,7 +161,7 @@ def forward(ctx, tensor):
 
     @staticmethod
     def backward(ctx, g):
-        return Float8Tensor.to_float8(g), None, None
+        return g, None, None
 
 
 class Float8Tensor(torch.Tensor):
```
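
Pieced together from the hunks above, the two autograd functions now read roughly as below. This is a sketch rather than a copy of `float8_experimental/float8_tensor.py`: the forward bodies are simplified stand-ins and the parameter names are placeholders chosen so the argument counts line up with the backward return values in the diff; the backward methods themselves mirror the diff, with both directions passing the gradient through unchanged.

```python
import torch


class ToFloat8Sketch(torch.autograd.Function):
    """A differentiable conversion to fp8.
    * forward: convert from high precision to float8
    * backward: pass the gradient without changes
    """

    @staticmethod
    def forward(ctx, tensor, scale, float8_dtype, amax_buffer=None, mm_config=None):
        # placeholder for the real forward, which builds a Float8Tensor via
        # to_fp8_no_autograd(...); amax_buffer / mm_config stand in for the
        # non-differentiable config arguments
        return (tensor * scale).to(float8_dtype)

    @staticmethod
    def backward(ctx, g):
        # as in the diff: the gradient passes through, with one None per
        # non-differentiable forward argument
        return g, None, None, None, None


class FromFloat8Sketch(torch.autograd.Function):
    """A differentiable conversion from fp8.
    * forward: convert from float8 to high precision
    * backward: pass the gradient without changes
    """

    @staticmethod
    def forward(ctx, tensor, original_dtype=torch.float32, scale=None):
        # placeholder for the real forward, which calls
        # tensor.to_original_precision() on a Float8Tensor
        hp = tensor.to(original_dtype)
        return hp if scale is None else hp / scale

    @staticmethod
    def backward(ctx, g):
        # as in the diff: the gradient passes through unchanged
        return g, None, None
```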

test/test_base.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -56,6 +56,18 @@ def test_preserves_dtype(self) -> None:
         x3_hp = x2_lp.to_original_precision()
         self.assertTrue(x3_hp.dtype == hp_dtype)
 
+    def test_differentiable_casts(self) -> None:
+        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        for f8_dtype in lp_dtypes:
+            x = torch.randn(1).requires_grad_()
+            grad = torch.randn(1)
+            x_s = tensor_to_scale(x, f8_dtype)
+            x_f8 = Float8Tensor.to_float8(x, x_s, f8_dtype)
+            x_f8_hp = x_f8.to_original_precision()
+            x_f8_hp.backward(grad)
+            # the gradient should be unchanged through both casts
+            torch.testing.assert_close(grad, x.grad, rtol=0, atol=0)
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
```

0 commit comments
