
Commit 2623617

view as data

1 parent e4218e1 commit 2623617

File tree: 2 files changed (+50, −20 lines)


float8_experimental/float8_ops.py

Lines changed: 34 additions & 19 deletions
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import torch
 
@@ -64,24 +64,39 @@ def make_float8(data):
 
 
 # Errors: can't `cat_cuda` float8 e4m3fn
-# @implements([aten.cat.default])
-# def float8_cat(aten_op, args, kwargs=None):
-#     chunked_tensors: Tuple[Float8Tensor] = args[0]
-
-#     orig_dtype = args[0][0]._orig_dtype
-#     scale = args[0][0]._scale
-#     mm_config = args[0][0]._mm_config
-#     chunk_data = []
-#     for chunk in chunked_tensors:
-#         assert chunk._orig_dtype == orig_dtype, "Expecting all chunks to be of the same dtype"
-#         assert chunk._scale is scale, "Expecting all chunks to have the same scale as a result of a split"
-#         assert chunk._mm_config is mm_config, "Expecting all chunks to have the same mm config as a result of a split"
-#         chunk_data.append(chunk._data)
-#     new_data = aten_op(chunk_data, *args[1:], **kwargs)
-#     return Float8Tensor(new_data, scale, orig_dtype, mm_config)
-
-
-@implements([aten.sum.dim_IntList, aten.cat.default])
+@implements([aten.cat.default])
+def float8_cat(aten_op, args, kwargs=None):
+    chunked_tensors: Tuple[Float8Tensor] = args[0]
+
+    orig_dtype = chunked_tensors[0]._orig_dtype
+    scale = chunked_tensors[0]._scale
+    mm_config = chunked_tensors[0]._mm_config
+    fp8_dtype = chunked_tensors[0]._data.dtype
+    chunk_data = []
+    for chunk in chunked_tensors:
+        assert isinstance(
+            chunk, Float8Tensor
+        ), "Expecting all chunks to be of type Float8Tensor"
+        assert (
+            chunk._orig_dtype == orig_dtype
+        ), "Expecting all chunks to be of the same dtype"
+        assert (
+            chunk._scale is scale
+        ), "Expecting all chunks to have the same scale as a result of a split"
+        assert (
+            chunk._mm_config is mm_config
+        ), "Expecting all chunks to have the same mm config as a result of a split"
+        assert (
+            chunk._data.dtype == fp8_dtype
+        ), "Expecting all chunks to be of the same dtype as a result of a split"
+        chunk_data.append(chunk._data.view(torch.uint8))
+
+    new_data = aten_op(chunk_data, *args[1:], **kwargs)
+    new_data = new_data.view(fp8_dtype)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+
+
+@implements([aten.sum.dim_IntList])
 def float8_cast_up_op(aten_op, args, kwargs=None):
     """Be careful with this function, this is a "fallback" op that
     casts the output of the op to the original precision. And performs the op.
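The new float8_cat works because Tensor.view(dtype) between two one-byte dtypes is a pure bitwise reinterpretation: since there is no `cat_cuda` kernel for float8 e4m3fn, the raw bytes are concatenated through the ordinary uint8 kernel and the result is viewed back as float8. A minimal standalone sketch of that trick, assuming a PyTorch build with `torch.float8_e4m3fn` support (tensor names are illustrative, not from this repo):

    import torch

    # A direct torch.cat on float8 tensors can fail ("cat_cuda" not
    # implemented for float8 e4m3fn), so cat the raw bytes instead.
    a = torch.randn(4, 8).to(torch.float8_e4m3fn)
    b = torch.randn(4, 8).to(torch.float8_e4m3fn)

    # view(torch.uint8) reinterprets the same bytes; no values change.
    out = torch.cat([a.view(torch.uint8), b.view(torch.uint8)], dim=0)
    out = out.view(torch.float8_e4m3fn)

    assert out.shape == (8, 8)
    # The first four rows are bitwise identical to `a`.
    assert torch.all(out[:4].view(torch.uint8) == a.view(torch.uint8)).item()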

test/test_base.py

Lines changed: 16 additions & 1 deletion
@@ -44,6 +44,12 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
+def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
+    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
+    assert torch.all(a._data == b._data).item(), "data is not identical"
+    return True
+
+
 class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
@@ -68,6 +74,15 @@ def test_differentiable_casts(self) -> None:
         # the gradient should be unchanged through both casts
         torch.testing.assert_close(grad, x.grad, rtol=0, atol=0)
 
+    def test_split_cat(self):
+        a = torch.rand(16, 16, dtype=torch.bfloat16, device="cuda")
+        scale = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
+
+        splits = torch.split(fp8_a, 16)
+        catted = torch.cat(splits, dim=0)
+        assert bitwise_identical(fp8_a, catted)
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
@@ -357,7 +372,7 @@ def test_different_configs_error(self):
         ):
             a @ b
 
-    def test_merge_configs(sel):
+    def test_merge_configs(self):
        a = ScaledMMConfig(False, True, True)
        b = ScaledMMConfig(True, False, False)
        with pytest.raises(
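For reference, the round trip exercised by test_split_cat can also be checked without the Float8Tensor subclass. A hedged standalone sketch (CPU, illustrative names; assumes float8 support, and compares raw bytes since float8 comparison kernels may be unavailable on some builds):

    import torch

    x = torch.rand(16, 16, dtype=torch.bfloat16).to(torch.float8_e4m3fn)

    # split returns views of the same storage; cat the bytes and view back
    splits = torch.split(x, 8)
    catted = torch.cat([s.view(torch.uint8) for s in splits], dim=0)
    catted = catted.view(torch.float8_e4m3fn)

    # bitwise identity: the split/cat round trip must not alter any byte
    assert torch.all(x.view(torch.uint8) == catted.view(torch.uint8)).item()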
