@@ -50,7 +50,38 @@ def float8_desugar_op(aten_op, args, kwargs=None):
    )


- @implements([aten.sum.dim_IntList])
+ @implements([aten.split.Tensor])
+ def float8_split(aten_op, args, kwargs=None):
+     new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)
+
+     def make_float8(data):
+         return Float8Tensor(
+             data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+         )
+
+     out = map(make_float8, new_data_tensors)
+     return list(out)
+
+
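As an illustration, a minimal usage sketch of the new split handler; the `Float8Tensor.to_float8` construction step is an assumed helper for this example and is not part of the diff:

# Hypothetical usage sketch (construction API assumed; only torch.split
# dispatching to float8_split above is from this diff):
import torch

scale = torch.tensor(1.0, device="cuda")
x_fp8 = Float8Tensor.to_float8(  # assumed constructor, not shown in this diff
    torch.randn(8, 16, device="cuda"), scale, torch.float8_e4m3fn
)
chunks = torch.split(x_fp8, 4, dim=0)  # dispatches to float8_split above
# Each chunk is itself a Float8Tensor, sharing the parent's single scale:
assert all(isinstance(c, Float8Tensor) for c in chunks)
assert all(c._scale is x_fp8._scale for c in chunks)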
+ # Disabled: errors with `cat_cuda` not implemented for float8 e4m3fn
+ # @implements([aten.cat.default])
+ # def float8_cat(aten_op, args, kwargs=None):
+ #     chunked_tensors: Tuple[Float8Tensor] = args[0]
+
+ #     orig_dtype = args[0][0]._orig_dtype
+ #     scale = args[0][0]._scale
+ #     mm_config = args[0][0]._mm_config
+ #     chunk_data = []
+ #     for chunk in chunked_tensors:
+ #         assert chunk._orig_dtype == orig_dtype, "Expecting all chunks to be of the same dtype"
+ #         assert chunk._scale is scale, "Expecting all chunks to have the same scale as a result of a split"
+ #         assert chunk._mm_config is mm_config, "Expecting all chunks to have the same mm config as a result of a split"
+ #         chunk_data.append(chunk._data)
+ #     new_data = aten_op(chunk_data, *args[1:], **kwargs)
+ #     return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+
+
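Because the fused fp8 concatenation kernel is unavailable, the change below instead routes `aten.cat.default` through the cast-up fallback. A sketch of the expected behavior, continuing the hypothetical `x_fp8` example above:

# Under the fallback, cat on Float8Tensor chunks should return a plain
# tensor in the original precision rather than a Float8Tensor (an
# expectation based on the fallback's docstring, not verified output):
y = torch.cat(torch.split(x_fp8, 4, dim=0), dim=0)
assert not isinstance(y, Float8Tensor)
assert y.dtype == x_fp8._orig_dtype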
+ @implements([aten.sum.dim_IntList, aten.cat.default])
def float8_cast_up_op(aten_op, args, kwargs=None):
    """Be careful with this function, this is a "fallback" op that
    performs the op and casts the output to the original precision.
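Only the start of `float8_cast_up_op` appears in this hunk. A minimal sketch of the cast-up pattern its docstring describes, assuming a `to_original_precision()` helper on Float8Tensor and `tree_map` for nested arguments (neither is shown in this diff, and the real body may differ):

from torch.utils._pytree import tree_map

def cast_up_sketch(aten_op, args, kwargs=None):
    # Hedged sketch of the cast-up fallback, NOT the exact body from this PR.
    def unwrap(t):
        # Cast any Float8Tensor back to its original precision (assumed helper).
        return t.to_original_precision() if isinstance(t, Float8Tensor) else t

    # tree_map reaches into nested containers, e.g. the list of tensors
    # that aten.cat.default receives as its first argument.
    new_args = tree_map(unwrap, args)
    new_kwargs = tree_map(unwrap, kwargs or {})
    # Run the op in the original precision; the result is a plain tensor.
    return aten_op(*new_args, **new_kwargs)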