This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b5a444a

Authored by ani300, committed by facebook-github-bot
Add more compile compatibility for Float8Tensor ops (#285)
Summary: Pull Request resolved: #285

Reviewed By: vkuzo

Differential Revision: D59068281

Pulled By: drisspg

fbshipit-source-id: 18fa34db74cf60e85ff372ff1091c107119403a0
1 parent 57136bd commit b5a444a
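
For context on what "compile compatibility" buys here: the three ops added to the desugar list below (aten.slice.Tensor, aten.transpose.int, aten.fill_.Scalar), plus the new index_put_ and copy_ handlers, let code that hits these ops dispatch correctly on Float8Tensor, including under torch.compile. A minimal sketch of the kind of call this enables, using the tensor_to_scale / Float8Tensor.to_float8 helpers that appear in test/test_base.py below; the import paths are assumptions, and the snippet is an illustration rather than code from this commit:

import torch

# Assumed module paths; the helpers themselves appear in test/test_base.py.
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

def slice_then_transpose(t):
    # slice.Tensor and transpose.int now route through float8_desugar_op
    return t[:8].transpose(0, 1)

x = torch.rand(16, 16, dtype=torch.bfloat16)
scale_x = tensor_to_scale(x, torch.float8_e4m3fn)
fp8_x = Float8Tensor.to_float8(x, scale_x, torch.float8_e4m3fn)

# Before this commit these ops were absent from Float8Tensor's dispatch
# table, so a call like this could not be serviced.
out = torch.compile(slice_then_transpose)(fp8_x)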

2 files changed (+93, -0 lines changed)

float8_experimental/float8_ops.py

Lines changed: 55 additions & 0 deletions
@@ -42,6 +42,9 @@ def decorator(func):
         aten.as_strided.default,
         aten.clone.default,
         aten.detach.default,
+        aten.slice.Tensor,
+        aten.transpose.int,
+        aten.fill_.Scalar,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
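
The registration machinery itself sits outside this hunk. What follows is a sketch of what it plausibly looks like, inferred from the hunk headers (def decorator(func), def float8_desugar_op(...)) and the rewrapping pattern visible in the next hunk; the table name FLOAT8_OPS_TABLE is an assumption, not something taken from this diff:

FLOAT8_OPS_TABLE = {}  # assumed name: maps each aten op to its Float8Tensor handler

def implements(aten_ops):
    """Register func as the handler for each of the given aten ops."""
    def decorator(func):
        for op in aten_ops:
            FLOAT8_OPS_TABLE[op] = func
        return func
    return decorator

def float8_desugar_op(aten_op, args, kwargs=None):
    # "Desugaring": apply the op to the raw fp8 payload, then rewrap the
    # result with the original scale, dtype and mm config unchanged. The op
    # must therefore be one whose output is still meaningful under the same
    # per-tensor scale (views, slices, transposes, fills of the raw data).
    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
    return Float8Tensor(
        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
    )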
@@ -263,3 +266,55 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     return Float8Tensor(
         fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
+
+
+@implements([aten.index_put_.default])
+def index_put_fp8(aten_op, args, kwargs=None):
+    fp8_self = args[0]
+    fp8_values = args[2]
+    assert isinstance(fp8_self, Float8Tensor)
+    assert isinstance(fp8_values, Float8Tensor)
+    assert fp8_self._scale == fp8_values._scale
+    assert fp8_self.dtype == fp8_values.dtype
+    assert fp8_self._orig_dtype == fp8_values._orig_dtype
+
+    fp8_data = fp8_self._data
+    fp8_values_data = fp8_values._data
+    fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
+    return Float8Tensor(
+        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+    )
+
+
+@implements([aten.copy_.default])
+def copy_fp8(aten_op, args, kwargs=None):
+    # For a copy op with Float8Tensors involved, only the following combinations are allowed:
+    #   1. self is a high-precision (hp) tensor, src is a Float8Tensor:
+    #      src is upcast and unscaled to go into the hp tensor
+    #   2. self and src are both Float8Tensors:
+    #      the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat)
+    # Every other combination is banned, as the semantics are not well defined.
+
+    self = args[0]
+    src = args[1]
+
+    if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        src_hp = src.to_original_precision()
+        return aten_op(self, src_hp, *args[2:], **kwargs)
+    elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        assert (
+            self._orig_dtype == src._orig_dtype
+        ), "Expecting both Float8Tensors to be of the same dtype"
+        assert (
+            self._scale == src._scale
+        ), "Expecting both Float8Tensors to have the same scale"
+        assert (
+            self._mm_config == src._mm_config
+        ), "Expecting both Float8Tensors to have the same mm config"
+        assert (
+            self._data.dtype == src._data.dtype
+        ), "Expecting both Float8Tensors' raw data to be of the same dtype"
+        fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
+        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+    else:
+        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
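
One subtlety worth flagging: of the three newly desugared ops, fill_.Scalar is the only value-changing one. Under the desugar pattern sketched above, the scalar presumably lands in the raw fp8 payload while the existing scale is kept, so the fill is exactly scale-correct only for values such as zero. A hedged eager-mode illustration, reusing the assumed imports from the earlier snippet:

a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)

fp8_a.fill_(0.0)  # zeroes the fp8 payload in place; the scale is untouched
assert torch.all(fp8_a.to_original_precision() == 0.0)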

test/test_base.py

Lines changed: 38 additions & 0 deletions
@@ -83,6 +83,44 @@ def test_split_cat(self):
         catted = torch.cat(splits, dim=0)
         assert bitwise_identical(fp8_a, catted)
 
+    def test_index_put(self):
+        a = torch.rand(16, dtype=torch.bfloat16)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)
+
+        index = torch.randint(0, 15, (16,), dtype=torch.long)
+
+        b = torch.rand(16, 16, dtype=torch.bfloat16)
+        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
+        fp8_b = Float8Tensor.to_float8(b, scale_a, torch.float8_e4m3fn)
+        fp8_b_bad = Float8Tensor.to_float8(b, scale_b, torch.float8_e4m3fn)
+
+        with self.assertRaises(AssertionError):
+            b[index] = fp8_a
+            fp8_b[index] = a
+            fp8_b_bad[index] = fp8_a
+        fp8_b[index] = fp8_a
+
+    def test_copy_(self):
+        a = torch.rand(16, dtype=torch.bfloat16)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)
+
+        b = torch.empty(16, dtype=torch.bfloat16)
+        b.copy_(fp8_a)  # should work: hp tensor <- Float8Tensor
+        torch.testing.assert_close(b, fp8_a.to_original_precision())
+        with self.assertRaises(RuntimeError):
+            fp8_a.copy_(b)  # should fail: Float8Tensor <- hp tensor is banned
+
+        fp8_b = Float8Tensor(
+            torch.empty(16, dtype=torch.float8_e4m3fn),
+            scale_a,
+            torch.bfloat16,
+            fp8_a._mm_config,
+        )
+        fp8_b.copy_(fp8_a)  # should work: all Float8Tensor properties match
+        torch.testing.assert_close(fp8_a._data, fp8_b._data)
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
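To exercise just the new coverage, assuming a standard pytest setup at the repository root (the tests use unittest-style asserts, which pytest collects fine):

pytest test/test_base.py -k "test_index_put or test_copy_"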