This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 1a57df3

committed: some more shuffiling
1 parent: fdbba31

File tree

1 file changed: +43 -30 lines changed

float8_experimental/float8_tensor.py

Lines changed: 43 additions & 30 deletions
@@ -24,7 +24,15 @@ def to_fp8_no_autograd(
     """Convert a tensor to float8 without autograd
     This is used in multiple places in the codebase to convert a tensor to float8
 
-    This function will calculate the scale, do the scaling, and then convert to a Float8Tensor
+    This function will apply the scaling, and then convert to a Float8Tensor
+
+    Note:
+        We will call this function with a DTensor subclass. Ideally this would be an aten OP
+        that DTensor could overload to ensure proper semantics. There are some techincal issues
+        with that composing with FakeTensor, so we special case here.
+
+        DTensor Invariant: DTensor must always be the outer most tensor subclass
+
     Args:
         x: the tensor to convert
         scale: the scale to use to convert the tensor
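
The "apply the scaling, and then convert" step in the docstring can be pictured with a minimal standalone sketch. This is not the library's implementation, and the amax-based scale below is only illustrative, since to_fp8_no_autograd receives the scale from its caller:

import torch

x = torch.randn(8, 8, dtype=torch.float32)
# Illustrative per-tensor scale; in the library the caller computes and passes it in.
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
bits_fp8 = (x * scale).to(torch.float8_e4m3fn)  # apply the scaling, then cast to float8
x_approx = bits_fp8.to(torch.float32) / scale   # recover an approximation in the original precision
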
@@ -50,6 +58,30 @@ def to_fp8_no_autograd(
     return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
 
 
+def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
+    """Convert a tensor from float8 without autograd
+
+    This function will handle 3 cases:
+    1. If the tensor is a DTensor, it will convert the inner tensor to the original precision
+    2. If the tensor is a Float8Tensor, it will convert the tensor to the original precision
+    3. If the tensor is a regular tensor, it will pass through this tensor
+
+    Args:
+        x: the tensor to convert
+    """
+
+    def to_original_precision(grad):
+        if isinstance(grad, Float8Tensor):
+            return grad.to_original_precision()
+        else:
+            return grad
+
+    if isinstance(x, DTensor):
+        local_grad = x.to_local()
+        original_precision_grad = to_original_precision(local_grad)
+        return DTensor.from_local(original_precision_grad, x.device_mesh, x.placements)
+    else:
+        return to_original_precision(x)
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -62,17 +94,16 @@ def forward(
         tensor: torch.Tensor,
         scale: torch.Tensor,
         float8_dtype=torch.float8_e4m3fn,
-        amax_buffer=None,
+        amax_buffer: Optional[torch.Tensor] = None,
         emulate: bool = False,
     ):
-        """Converts a higher precision tensor to float8 in a differentiable way.
-
-        Note:
-        We will call this function with a DTensor subclass. Ideally this would be an aten OP
-        that DTensor could overload to ensure proper semantics. There are some techincal issues
-        with that composing with FakeTensor, so we special case here.
-
-        DTensor Invariant: DTensor must always be the outer most tensor subclass
+        """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
+        Args
+            tensor: the tensor to convert
+            scale: the scale to use to convert the tensor
+            float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
+            amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
+            emulate: whether to emulate the matmuls in fp32
         """
         if amax_buffer is not None:
             amax_buffer.fill_(tensor_to_amax(tensor))
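
The wrapper pattern documented here (optionally fill an amax buffer, delegate the cast to a non-autograd helper, and hand back one gradient slot per forward argument) can be shown with a small self-contained analogue. The class below is an assumption for illustration only: it uses float16 as a stand-in for float8 and is not the repo's ToFloat8ConstrFunc:

from typing import Optional

import torch


class ToLowPrecision(torch.autograd.Function):
    """Toy analogue of the wrapper above, with float16 standing in for float8."""

    @staticmethod
    def forward(ctx, tensor, scale, dtype=torch.float16, amax_buffer: Optional[torch.Tensor] = None):
        ctx.orig_dtype = tensor.dtype
        if amax_buffer is not None:
            amax_buffer.fill_(tensor.abs().max())  # record amax before converting
        return (tensor * scale).to(dtype)          # apply the scaling, then cast

    @staticmethod
    def backward(ctx, g):
        # Gradient passes through in the original precision; every non-tensor
        # argument of forward gets a None gradient slot.
        return g.to(ctx.orig_dtype), None, None, None


x = torch.randn(4, 4, requires_grad=True)
amax = torch.zeros(())
y = ToLowPrecision.apply(x, torch.tensor(2.0), torch.float16, amax)
y.sum().backward()  # x.grad comes back in float32
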
@@ -81,26 +112,8 @@ def forward(
 
     @staticmethod
     def backward(ctx, g):
-        def to_original_precision(grad):
-            if isinstance(grad, Float8Tensor):
-                return grad.to_original_precision()
-            else:
-                return grad
-
-        if isinstance(g, DTensor):
-            local_grad = g.to_local()
-            original_precision_grad = to_original_precision(local_grad)
-            return (
-                DTensor.from_local(
-                    original_precision_grad, g.device_mesh, g.placements
-                ),
-                None,
-                None,
-                None,
-                None,
-            )
-        else:
-            return to_original_precision(g), None, None, None, None
+        grad = from_fp8_no_autograd(g)
+        return grad, None, None, None, None
 
 
 @torch._dynamo.allow_in_graph
