@@ -42,6 +42,9 @@ def decorator(func):
         aten.as_strided.default,
         aten.clone.default,
         aten.detach.default,
+        aten.slice.Tensor,
+        aten.transpose.int,
+        aten.fill_.Scalar,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
@@ -263,3 +266,55 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     return Float8Tensor(
         fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
+
+
+@implements([aten.index_put_.default])
+def index_put_fp8(aten_op, args, kwargs=None):
+    fp8_self = args[0]
+    fp8_values = args[2]
+    assert isinstance(fp8_self, Float8Tensor)
+    assert isinstance(fp8_values, Float8Tensor)
+    assert fp8_self._scale == fp8_values._scale
+    assert fp8_self.dtype == fp8_values.dtype
+    assert fp8_self._orig_dtype == fp8_values._orig_dtype
+
+    fp8_data = fp8_self._data
+    fp8_values_data = fp8_values._data
+    fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
+    return Float8Tensor(
+        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+    )
+
+
+@implements([aten.copy_.default])
+def copy_fp8(aten_op, args, kwargs=None):
+    # For a copy op with Float8Tensors involved, only the following combinations are allowed:
+    # 1. self is a high precision (hp) tensor, src is a Float8Tensor:
+    #    in this case src is upcasted and unscaled to go into the hp tensor
+    # 2. self and src are Float8Tensors:
+    #    the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat)
+    # Every other combination is banned as the semantics are not well defined
+
+    self = args[0]
+    src = args[1]
+
+    if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        src_hp = src.to_original_precision()
+        return aten_op(self, src_hp, *args[2:], **kwargs)
+    elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        assert (
+            self._orig_dtype == src._orig_dtype
+        ), "Expecting both Float8Tensors to be of the same dtype"
+        assert (
+            self._scale == src._scale
+        ), "Expecting both Float8Tensors to have the same scale"
+        assert (
+            self._mm_config == src._mm_config
+        ), "Expecting both Float8Tensors to have the same mm config"
+        assert (
+            self._data.dtype == src._data.dtype
+        ), "Expecting both Float8Tensors to be of the same dtype"
+        fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
+        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+    else:
+        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
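
As a quick illustration of the semantics described in the copy_fp8 comments above, here is a minimal usage sketch. The import path (float8_experimental.float8_tensor), the direct Float8Tensor(data, scale, orig_dtype, mm_config) construction, the make_fp8 helper, and mm_config=None are assumptions made only for this sketch; in real code the library's own casting helpers would normally produce the Float8Tensor instances.

import torch
# Assumed import path for illustration only.
from float8_experimental.float8_tensor import Float8Tensor

scale = torch.tensor(1.0)

def make_fp8(hp: torch.Tensor) -> Float8Tensor:
    # Hypothetical helper: direct construction, mirroring the return statements in the diff.
    # mm_config=None is an assumption for this sketch.
    return Float8Tensor(hp.to(torch.float8_e4m3fn), scale, hp.dtype, None)

a_fp8 = make_fp8(torch.randn(4, 4, dtype=torch.bfloat16))

# copy_ case 1: high-precision destination, Float8Tensor source.
# The source is upcast via to_original_precision() before the copy.
dst_hp = torch.empty(4, 4, dtype=torch.bfloat16)
dst_hp.copy_(a_fp8)

# copy_ case 2: both sides are Float8Tensors with matching scale, dtypes, and mm_config.
b_fp8 = make_fp8(torch.randn(4, 4, dtype=torch.bfloat16))
b_fp8.copy_(a_fp8)

# index_put_: both self and values must be Float8Tensors with the same scale;
# advanced-indexing assignment dispatches to aten.index_put_.
values = make_fp8(torch.randn(2, 4, dtype=torch.bfloat16))
a_fp8[torch.tensor([0, 2])] = values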