This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 48bd03a

some more shuffling

1 parent fdbba31 commit 48bd03a

2 files changed: +45 -44 lines changed

float8_experimental/float8_tensor.py

Lines changed: 45 additions & 30 deletions
@@ -24,7 +24,15 @@ def to_fp8_no_autograd(
     """Convert a tensor to float8 without autograd
     This is used in multiple places in the codebase to convert a tensor to float8

-    This function will calculate the scale, do the scaling, and then convert to a Float8Tensor
+    This function will apply the scaling, and then convert to a Float8Tensor
+
+    Note:
+        We will call this function with a DTensor subclass. Ideally this would be an aten OP
+        that DTensor could overload to ensure proper semantics. There are some technical issues
+        with that composing with FakeTensor, so we special case here.
+
+        DTensor Invariant: DTensor must always be the outermost tensor subclass
+
     Args:
         x: the tensor to convert
         scale: the scale to use to convert the tensor
@@ -50,6 +58,32 @@ def to_fp8_no_autograd(
     return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)


+def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
+    """Convert a tensor from float8 without autograd
+
+    This function will handle 3 cases:
+        1. If the tensor is a DTensor, it will convert the inner tensor to the original precision
+        2. If the tensor is a Float8Tensor, it will convert the tensor to the original precision
+        3. If the tensor is a regular tensor, it will be passed through unchanged
+
+    Args:
+        x: the tensor to convert
+    """
+
+    def to_original_precision(grad):
+        if isinstance(grad, Float8Tensor):
+            return grad.to_original_precision()
+        else:
+            return grad
+
+    if isinstance(x, DTensor):
+        local_grad = x.to_local()
+        original_precision_grad = to_original_precision(local_grad)
+        return DTensor.from_local(original_precision_grad, x.device_mesh, x.placements)
+    else:
+        return to_original_precision(x)
+
+
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -62,17 +96,16 @@ def forward(
         tensor: torch.Tensor,
         scale: torch.Tensor,
         float8_dtype=torch.float8_e4m3fn,
-        amax_buffer=None,
+        amax_buffer: Optional[torch.Tensor] = None,
         emulate: bool = False,
     ):
-        """Converts a higher precision tensor to float8 in a differentiable way.
-
-        Note:
-        We will call this function with a DTensor subclass. Ideally this would be an aten OP
-        that DTensor could overload to ensure proper semantics. There are some techincal issues
-        with that composing with FakeTensor, so we special case here.
-
-        DTensor Invariant: DTensor must always be the outer most tensor subclass
+        """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
+        Args:
+            tensor: the tensor to convert
+            scale: the scale to use to convert the tensor
+            float8_dtype: the float8 dtype, either torch.float8_e4m3fn or torch.float8_e5m2
+            amax_buffer: an optional buffer to store the amax value in prior to conversion
+            emulate: whether to emulate the matmuls in fp32
         """
         if amax_buffer is not None:
             amax_buffer.fill_(tensor_to_amax(tensor))
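A quick sketch of the amax_buffer behaviour documented above: when a pre-allocated buffer is passed in, forward() fills it in place with amax(tensor) before doing the conversion, so the caller can collect amax statistics. The buffer below is hypothetical and only for illustration, under the same assumptions as the previous sketch:

import torch

from float8_experimental.float8_tensor import ToFloat8ConstrFunc

x = torch.randn(16, 16)
scale = torch.tensor(torch.finfo(torch.float8_e4m3fn).max) / x.abs().max()

amax_buffer = torch.zeros(())  # hypothetical pre-allocated scalar buffer
_ = ToFloat8ConstrFunc.apply(x, scale, torch.float8_e4m3fn, amax_buffer)

# The buffer was filled in place with max(|x|) prior to the conversion.
assert torch.allclose(amax_buffer, x.abs().max())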
@@ -81,26 +114,8 @@ def forward(

     @staticmethod
     def backward(ctx, g):
-        def to_original_precision(grad):
-            if isinstance(grad, Float8Tensor):
-                return grad.to_original_precision()
-            else:
-                return grad
-
-        if isinstance(g, DTensor):
-            local_grad = g.to_local()
-            original_precision_grad = to_original_precision(local_grad)
-            return (
-                DTensor.from_local(
-                    original_precision_grad, g.device_mesh, g.placements
-                ),
-                None,
-                None,
-                None,
-                None,
-            )
-        else:
-            return to_original_precision(g), None, None, None, None
+        grad = from_fp8_no_autograd(g)
+        return grad, None, None, None, None


 @torch._dynamo.allow_in_graph
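To make the refactored backward() concrete, here is a tiny sketch of the delegation. ctx is unused by this backward, so calling the staticmethod directly with None is enough to see the routing; this is illustrative only, not how autograd actually invokes it:

import torch

from float8_experimental.float8_tensor import ToFloat8ConstrFunc, from_fp8_no_autograd

# Pretend g is an incoming gradient; a plain tensor exercises the pass-through path.
g = torch.randn(16, 16)

grads = ToFloat8ConstrFunc.backward(None, g)
assert grads[0] is from_fp8_no_autograd(g)  # same routing as the new backward body
assert grads[1:] == (None, None, None, None)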

test/test_dtensor.py

Lines changed: 0 additions & 14 deletions
@@ -22,19 +22,6 @@
 from tqdm import tqdm


-# def setup_distributed_fake():
-#     world_size = int(os.environ.get("WORLD_SIZE", -1))
-#     print(world_size)
-#     RANK = int(os.environ.get("RANK", -1))
-#     init_process_group("fake", store=FakeStore(), rank=RANK, world_size=world_size)
-#     # device_mesh = init_device_mesh("cuda", (world_size,))
-#     device_mesh = DeviceMesh("cuda", (world_size,))
-
-#     # seed must be the same in all processes
-#     torch.manual_seed(1)
-#     return device_mesh
-
-
 def setup_distributed():
     world_size = int(os.environ.get("WORLD_SIZE", -1))
     device_mesh = init_device_mesh("cuda", (world_size,))
@@ -155,7 +142,6 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     # other test files to not use TestCase but instead just add the test
     # cases in the main func.
     device_mesh = setup_distributed()
-    # device_mesh = setup_distributed_fake()
     tests = [
         test_scaled_mm,
         test_fp8_redistribute,
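For completeness, a sketch of the DTensor case that these tests exercise. It mirrors setup_distributed() above and assumes a launcher such as torchrun sets WORLD_SIZE, with one CUDA device per rank; the import paths and the Shard placement are assumptions based on the PyTorch version this test file targets, not something the commit pins down:

import os

import torch
from torch.distributed._tensor import DTensor, Shard, init_device_mesh

from float8_experimental.float8_tensor import ToFloat8ConstrFunc, from_fp8_no_autograd

world_size = int(os.environ.get("WORLD_SIZE", -1))
device_mesh = init_device_mesh("cuda", (world_size,))

local = torch.randn(16, 16, device="cuda")
scale = torch.tensor(torch.finfo(torch.float8_e4m3fn).max, device="cuda") / local.abs().max()
local_fp8 = ToFloat8ConstrFunc.apply(local, scale, torch.float8_e4m3fn)

# Per the invariant, DTensor stays the outermost subclass, wrapping a Float8Tensor shard.
dist_fp8 = DTensor.from_local(local_fp8, device_mesh, [Shard(0)])

# Case 1: the helper unwraps the DTensor, converts the local shard back to the
# original precision, and re-wraps it with the same mesh and placements.
out = from_fp8_no_autograd(dist_fp8)
assert isinstance(out, DTensor) and out.to_local().dtype == local.dtype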
