This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b67e5cf

drisspg authored and facebook-github-bot committed
Ensure DTensor wraps inner float8 Tensors (#224)
Summary: This special-cases the creation of Float8Tensor for when the input data and scale are DTensors. If that's the case, we want to maintain the DTensor invariant that it is the outermost tensor wrapping the Float8Tensor. I add an example forward/backward which doesn't make the most sense, and which was originally hitting the error below.

## Updated

I realized that I wasn't calling the float8 constructor in the backward that is able to handle DTensor, which led to the error below. I have remedied that and was able to get E2E working (numerical correctness not yet tested). This did require that the mm output be manually waited on in matmul; to me this feels like a bug in AsyncCollectiveTensor, but I still need to track it down.

##### Old error
```
[rank1]:   File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_tensor.py", line 221, in __torch_dispatch__
[rank1]:     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank1]:   File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_ops.py", line 81, in float8_mm
[rank1]:     assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor), (
[rank1]: AssertionError: Expecting both Float8Tensor for mm inputs but found <class 'float8_experimental.float8_tensor.Float8Tensor'> and <class 'torch.distributed._tensor.api.DTensor'>
E0222 16:10:18.512000 140692135212864 torch/distri
```

Why the error?
1. The output of the scaled_mm is a regular DTensor(torch.Tensor) (the op converts two DTensor(Float8Tensor) inputs into a DTensor output).
2. We send this output through the no-op forward (NoopFwToFloat8E5M2Bw), which does nothing in the forward and converts the grad to a Float8Tensor of the e5m2 dtype in the backward.
3. This creation of the Float8Tensor from the backpropping DTensor hits the special logic in the Float8Tensor construction that makes sure DTensor re-wraps the Float8Tensor.

Pull Request resolved: #224

Reviewed By: bdhirsh

Differential Revision: D54204167

Pulled By: drisspg

fbshipit-source-id: 9c3c3ccb3cae8b90f5ab5c61fc0e7b96d89176d3
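To make the invariant concrete, here is a minimal sketch (not part of the commit, untested) that mirrors the new `test_dtensor_cast_to_fp8` test added below: when the data and scale passed to `Float8Tensor.to_float8` are DTensors, the result is still a DTensor whose local shard is the Float8Tensor. The helper name and the assumption of an already-initialized 1-D device mesh are illustrative.

```python
import torch
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale


def check_dtensor_wraps_float8(mesh: DeviceMesh, size: int = 16) -> None:
    # Shard a plain fp32 tensor over the mesh
    x_fp32 = torch.rand(size, size, device=mesh.device_type)
    dist_x = distribute_tensor(x_fp32, mesh, [Shard(0)])

    # tensor_to_scale on a DTensor input yields a DTensor scale as well
    dist_scale = tensor_to_scale(dist_x, torch.float8_e4m3fn).float()

    dist_x_fp8 = Float8Tensor.to_float8(dist_x, dist_scale, torch.float8_e4m3fn)

    # DTensor stays the outermost subclass; the float8 payload is the local shard
    assert isinstance(dist_x_fp8, DTensor)
    assert isinstance(dist_x_fp8.to_local(), Float8Tensor)
```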
1 parent 47facc8 commit b67e5cf

4 files changed (+187 -37 lines)

float8_experimental/float8_dynamic_linear.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -30,7 +30,10 @@ def forward(
 
     @staticmethod
     def backward(ctx, gradY):
-        fp8_tensor = to_fp8_no_autograd(gradY, torch.float8_e5m2, ctx.emulate)
+        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+        fp8_tensor = to_fp8_no_autograd(
+            gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
+        )
         return fp8_tensor, None
 
```

float8_experimental/float8_ops.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -78,9 +78,14 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
 
 @implements([aten.mm.default])
 def float8_mm(aten_op, args, kwargs=None):
-    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
     a = args[0]
     b = args[1]
+
+    assert isinstance(a, Float8Tensor) and isinstance(
+        b, Float8Tensor
+    ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
+        type(a), type(b)
+    )
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     if a._emulate:
```

float8_experimental/float8_tensor.py

Lines changed: 118 additions & 32 deletions
```diff
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 
@@ -13,9 +13,91 @@
     to_fp8_saturated,
 )
 
+from torch.distributed._tensor import DTensor
+
 aten = torch.ops.aten
 
 
+def to_fp8_no_autograd(
+    x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
+) -> "Float8Tensor":
+    """Convert a tensor to float8 without autograd
+    This is used in multiple places in the codebase to convert a tensor to float8
+
+    This function will apply the scaling, and then convert to a Float8Tensor
+
+    Note:
+    We will call this function with a DTensor subclass. Ideally this would be an aten OP
+    that DTensor could overload to ensure proper semantics. There are some techincal issues
+    with that composing with FakeTensor, so we special case here.
+
+    DTensor Invariant: DTensor must always be the outer most tensor subclass
+
+    Args:
+        x: the tensor to convert
+        scale: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        emulate: whether to emulate the matmuls in fp32
+    """
+    x_scaled = x * x_scale
+    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
+
+    if isinstance(bits_fp8, DTensor):
+        assert isinstance(
+            x, DTensor
+        ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
+        bits_mesh = bits_fp8.device_mesh
+        bits_placements = bits_fp8.placements
+        local_bits = bits_fp8.to_local()
+        local_scale = x_scale.to_local()
+        inner_float8_tensor = Float8Tensor(
+            local_bits, local_scale, x.dtype, emulate=emulate
+        )
+        return DTensor.from_local(
+            inner_float8_tensor,
+            bits_mesh,
+            bits_placements,
+            run_check=False,
+            shape=bits_fp8.size(),
+            stride=bits_fp8.stride(),
+        )
+
+    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
+
+
+def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
+    """Convert a tensor from float8 without autograd
+
+    This function will handle 3 cases:
+    1. If the tensor is a DTensor, it will convert the inner tensor to the original precision
+    2. If the tensor is a Float8Tensor, it will convert the tensor to the original precision
+    3. If the tensor is a regular tensor, it will pass through this tensor
+
+    Args:
+        x: the tensor to convert
+    """
+
+    def to_original_precision(grad):
+        if isinstance(grad, Float8Tensor):
+            return grad.to_original_precision()
+        else:
+            return grad
+
+    if isinstance(x, DTensor):
+        local_grad = x.to_local()
+        original_precision_grad = to_original_precision(local_grad)
+        return DTensor.from_local(
+            original_precision_grad,
+            x.device_mesh,
+            x.placements,
+            run_check=False,
+            shape=x.size(),
+            stride=x.stride(),
+        )
+    else:
+        return to_original_precision(x)
+
+
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -25,25 +107,29 @@ class ToFloat8ConstrFunc(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        tensor,
-        scale: float,
+        tensor: torch.Tensor,
+        scale: torch.Tensor,
         float8_dtype=torch.float8_e4m3fn,
-        amax_buffer=None,
+        amax_buffer: Optional[torch.Tensor] = None,
         emulate: bool = False,
     ):
+        """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
+        Args
+            tensor: the tensor to convert
+            scale: the scale to use to convert the tensor
+            float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
+            amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
+            emulate: whether to emulate the matmuls in fp32
+        """
         if amax_buffer is not None:
             amax_buffer.fill_(tensor_to_amax(tensor))
 
-        tensor_scaled = tensor * scale
-        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
-        return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
+        return to_fp8_no_autograd(tensor, scale, float8_dtype, emulate)
 
     @staticmethod
     def backward(ctx, g):
-        if isinstance(g, Float8Tensor):
-            return g.to_original_precision(), None, None, None, None
-        else:
-            return g, None, None, None, None
+        grad = from_fp8_no_autograd(g)
+        return grad, None, None, None, None
 
 
 @torch._dynamo.allow_in_graph
@@ -95,7 +181,11 @@ def __new__(
         orig_dtype: torch.dtype,
         emulate=False,
     ):
-        assert scale.numel() == 1
+        assert (
+            scale.numel() == 1
+        ), "Scale should contain a single value, but got: {} elements".format(
+            scale.numel()
+        )
 
         self = torch.Tensor._make_wrapper_subclass(
             cls,
@@ -138,7 +228,13 @@ def to_original_precision(self):
 
     @staticmethod
    @torch._dynamo.allow_in_graph
-    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
+    def to_float8(
+        tensor: torch.Tensor,
+        scale: torch.Tensor,
+        float8_dtype: torch.dtype,
+        amax_buffer: Optional[torch.Tensor] = None,
+        emulate: bool = False,
+    ):
         """Converts a higher precision tensor to float8 in a differentiable way.
 
         Args:
@@ -168,28 +264,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # Lazy import to avoid circular dependency
         from float8_experimental.float8_ops import FLOAT8_OPS_TABLE
 
+        # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs
+        # And don't support mixed tensor subclasses. This will trigger the handler for
+        # the next type in the dispatch list. torch._C._TensorMeta is for FakeTensor
+        def allowed_subclasses(type):
+            return issubclass(cls, type) or isinstance(type, torch._C._TensorMeta)
+
+        if not all(allowed_subclasses(t) for t in types):
+            return NotImplemented
+
         if func in FLOAT8_OPS_TABLE:
             return FLOAT8_OPS_TABLE[func](func, args, kwargs)
         raise NotImplementedError(f"attempting to run {func}, this is not supported")
 
     # Do not force the Float8Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
-
-
-def to_fp8_no_autograd(
-    x: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
-) -> Float8Tensor:
-    """Convert a tensor to float8 without autograd
-    This is used in multiple places in the codebase to convert a tensor to float8
-
-    This function will calculate the scale, do the scaling, and then convert to a Float8Tensor
-    Args:
-        x: the tensor to convert
-        scale: the scale to use to convert the tensor
-        float8_dtype: the float8 dtype to use
-        emulate: whether to emulate the matmuls in fp32
-    """
-    x_scale = tensor_to_scale(x, float8_dtype)
-    x_scaled = x * x_scale
-    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
-    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
```
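As a reading aid (not part of the commit), the two single-process cases documented in `from_fp8_no_autograd` above can be exercised as in the rough sketch below; the DTensor case additionally unwraps with `to_local()` and re-wraps via `DTensor.from_local()`. This assumes a PyTorch build with the float8 dtypes and that `float8_experimental` is importable, and it has not been run against this commit.

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor, from_fp8_no_autograd
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, 16)
scale = tensor_to_scale(x, torch.float8_e4m3fn).float()
x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

# Case 2: a Float8Tensor is converted back out of float8 (no longer a Float8Tensor)
restored = from_fp8_no_autograd(x_fp8)
assert not isinstance(restored, Float8Tensor)

# Case 3: a regular tensor passes through untouched
assert from_fp8_no_autograd(x) is x
```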

test/test_dtensor.py

Lines changed: 59 additions & 3 deletions
```diff
@@ -12,10 +12,14 @@
 import torch
 import torch.nn as nn
 
+from float8_experimental.float8_dynamic_linear import NoopFwToFloat8E5M2Bw
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed import init_process_group
+from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.testing._internal.distributed.fake_pg import FakeStore
+from tqdm import tqdm
 
 
 def setup_distributed():
@@ -92,10 +96,62 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):
     assert out_local._data.dtype == fp8_dtype
 
 
+def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
+    device = mesh.device_type
+    fp8_dtype = torch.float8_e4m3fn
+
+    x_fp32 = torch.rand(size, size, device=device)
+    dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
+
+    dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float()
+    assert isinstance(dist_x_scale, DTensor)
+
+    dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype)
+    assert isinstance(dist_x_fp8, DTensor)
+
+
+def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
+    device = mesh.device_type
+    fp8_dtype = torch.float8_e4m3fn
+
+    x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
+    local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)
+    target = torch.rand(size, 2 * size, device=device)
+
+    dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
+    dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float()
+
+    dist_wight_fp32 = distribute_tensor(local_weight, mesh, [Shard(0)])
+    dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float()
+    dist_target = distribute_tensor(target, mesh, [Shard(0)])
+
+    dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype)
+    dist_weight_fp8 = Float8Tensor.to_float8(
+        dist_wight_fp32, dist_weight_scale, fp8_dtype
+    )
+
+    out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
+    out = NoopFwToFloat8E5M2Bw.apply(out, False)
+    assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
+    loss = torch.sum(torch.abs(out - dist_target))
+    loss.backward()
+
+
 if __name__ == "__main__":
     # float8 only works on CUDA H100 so we only test cuda and we follow
     # other test files to not use TestCase but instead just add the test
     # cases in the main func.
     device_mesh = setup_distributed()
-    test_scaled_mm(device_mesh)
-    test_fp8_redistribute(device_mesh)
+    tests = [
+        test_scaled_mm,
+        test_fp8_redistribute,
+        test_dtensor_cast_to_fp8,
+        test_dtensor_fp8_autograd,
+    ]
+
+    for test in tqdm(tests, desc="Running tests"):
+        try:
+            test(device_mesh)
+        except Exception as e:
+            print(f"Test {test.__name__} failed with error: {e}")
+            raise e
```
