 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import torch
 
@@ -64,24 +64,39 @@ def make_float8(data):
 
 
 # Errors cant `cat_cuda float8 e4m3fn`
-# @implements([aten.cat.default])
-# def float8_cat(aten_op, args, kwargs=None):
-#     chunked_tensors: Tuple[Float8Tensor] = args[0]
-
-#     orig_dtype = args[0][0]._orig_dtype
-#     scale = args[0][0]._scale
-#     mm_config = args[0][0]._mm_config
-#     chunk_data =[]
-#     for chunk in chunked_tensors:
-#         assert chunk._orig_dtype == orig_dtype, "Expecting all chunks to be of the same dtype"
-#         assert chunk._scale is scale, "Expecting all chunks to have thee same scale as a result of a split"
-#         assert chunk._mm_config is mm_config, "Expecting all chunks to have thee same mm config as a result of a split"
-#         chunk_data.append(chunk._data)
-#     new_data = aten_op(chunk_data, *args[1:], **kwargs)
-#     return Float8Tensor(new_data, scale, orig_dtype, mm_config)
-
-
-@implements([aten.sum.dim_IntList, aten.cat.default])
+@implements([aten.cat.default])
+def float8_cat(aten_op, args, kwargs=None):
+    chunked_tensors: Tuple[Float8Tensor] = args[0]
+
+    orig_dtype = chunked_tensors[0]._orig_dtype
+    scale = chunked_tensors[0]._scale
+    mm_config = chunked_tensors[0]._mm_config
+    fp8_dtype = chunked_tensors[0]._data.dtype
+    chunk_data = []
+    for chunk in chunked_tensors:
+        assert isinstance(
+            chunk, Float8Tensor
+        ), "Expecting all chunks to be of type Float8Tensor"
+        assert (
+            chunk._orig_dtype == orig_dtype
+        ), "Expecting all chunks to be of the same dtype"
+        assert (
+            chunk._scale is scale
+        ), "Expecting all chunks to have the same scale as a result of a split"
+        assert (
+            chunk._mm_config is mm_config
+        ), "Expecting all chunks to have the same mm config as a result of a split"
+        assert (
+            chunk._data.dtype == fp8_dtype
+        ), "Expecting all chunks to be of the same dtype as a result of a split"
+        chunk_data.append(chunk._data.view(torch.uint8))
+
+    new_data = aten_op(chunk_data, *args[1:], **kwargs)
+    new_data = new_data.view(fp8_dtype)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+
+
+@implements([aten.sum.dim_IntList])
 def float8_cast_up_op(aten_op, args, kwargs=None):
     """Be careful with this function, this is a "fallback" op that
     casts the output of the op to the original precision. And performs the op.