
Commit cb55df2

drisspg authored and facebook-github-bot committed
Enable restricted split + cat in order to enable SP (#253)
Summary: This comes from needing to support sequence parallelism in torchtitan.
Pull Request resolved: #253
Reviewed By: wanchaol
Differential Revision: D57134004
Pulled By: drisspg
fbshipit-source-id: e6c67ba7b2b96045867ece467400b0e4a3305e1d
1 parent 14b00aa commit cb55df2

File tree

3 files changed: +64 −3 lines changed
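The intent of the change is that torch.split and torch.cat on a Float8Tensor now stay in float8 and round-trip exactly, which is what DTensor-based sequence parallelism relies on. A minimal usage sketch, assuming the Float8Tensor and tensor_to_scale helpers exported by this repo at this commit:

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

# Cast a bf16 tensor to float8 with a single tensor-wide scale.
a = torch.rand(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)

# With this commit, split produces Float8Tensor chunks that share the original
# scale, and cat reassembles them bit-for-bit.
chunks = torch.split(fp8_a, 8, dim=0)
fp8_roundtrip = torch.cat(chunks, dim=0)
assert torch.all(fp8_a._data == fp8_roundtrip._data).item()

The "restricted" part of the title is that cat only accepts chunks sharing the same scale, mm config, and dtypes, i.e. chunks produced by splitting a single Float8Tensor.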

float8_experimental/float8_ops.py

Lines changed: 47 additions & 1 deletion

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import torch
 
@@ -50,6 +50,52 @@ def float8_desugar_op(aten_op, args, kwargs=None):
     )
 
 
+@implements([aten.split.Tensor])
+def float8_split(aten_op, args, kwargs=None):
+    new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)
+
+    def make_float8(data):
+        return Float8Tensor(
+            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        )
+
+    out = map(make_float8, new_data_tensors)
+    return list(out)
+
+
+# `cat_cuda` is not implemented for float8 e4m3fn, so concatenate the raw bytes as uint8
+@implements([aten.cat.default])
+def float8_cat(aten_op, args, kwargs=None):
+    chunked_tensors: Tuple[Float8Tensor] = args[0]
+
+    orig_dtype = chunked_tensors[0]._orig_dtype
+    scale = chunked_tensors[0]._scale
+    mm_config = chunked_tensors[0]._mm_config
+    fp8_dtype = chunked_tensors[0]._data.dtype
+    chunk_data = []
+    for chunk in chunked_tensors:
+        assert isinstance(
+            chunk, Float8Tensor
+        ), "Expecting all chunks to be of type Float8Tensor"
+        assert (
+            chunk._orig_dtype == orig_dtype
+        ), "Expecting all chunks to be of the same dtype"
+        assert (
+            chunk._scale is scale
+        ), "Expecting all chunks to have the same scale as a result of a split"
+        assert (
+            chunk._mm_config is mm_config
+        ), "Expecting all chunks to have the same mm config as a result of a split"
+        assert (
+            chunk._data.dtype == fp8_dtype
+        ), "Expecting all chunks to be of the same dtype as a result of a split"
+        chunk_data.append(chunk._data.view(torch.uint8))
+
+    new_data = aten_op(chunk_data, *args[1:], **kwargs)
+    new_data = new_data.view(fp8_dtype)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+
+
 @implements([aten.sum.dim_IntList])
 def float8_cast_up_op(aten_op, args, kwargs=None):
     """Be careful with this function, this is a "fallback" op that

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ dev = [
     "black==23.3.0",
     "usort==1.0.6",
     "ufmt==2.1.0",
-    "libcst==1.0.1",
+    "libcst==1.1.0",
     "pytest==7.4.0",
     "bumpver",
     "pip-tools",

test/test_base.py

Lines changed: 16 additions & 1 deletion

@@ -44,6 +44,12 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
+def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
+    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
+    assert torch.all(a._data == b._data).item(), "data is not identical"
+    return True
+
+
 class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
@@ -68,6 +74,15 @@ def test_differentiable_casts(self) -> None:
         # the gradient should be unchanged through both casts
         torch.testing.assert_close(grad, x.grad, rtol=0, atol=0)
 
+    def test_split_cat(self):
+        a = torch.rand(16, 16, dtype=torch.bfloat16)
+        scale = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
+
+        splits = torch.split(fp8_a, 16)
+        catted = torch.cat(splits, dim=0)
+        assert bitwise_identical(fp8_a, catted)
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
@@ -357,7 +372,7 @@ def test_different_configs_error(self):
         ):
             a @ b
 
-    def test_merge_configs(sel):
+    def test_merge_configs(self):
         a = ScaledMMConfig(False, True, True)
         b = ScaledMMConfig(True, False, False)
         with pytest.raises(
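To exercise just the new coverage locally (assuming the repo's standard pytest setup), something like:

pytest test/test_base.py -k test_split_cat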
