Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 93c8ac6

Browse files
committed
raise not implemented on other subclass types
1 parent c630f3b commit 93c8ac6

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

float8_experimental/float8_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from float8_experimental.float8_python_api import addmm_float8_unwrapped
1111
from float8_experimental.float8_tensor import Float8Tensor
1212
from float8_experimental.float8_utils import is_row_major
13-
from torch.distributed._functional_collectives import AsyncCollectiveTensor
1413
from torch.utils._pytree import tree_map
1514

1615
aten = torch.ops.aten
@@ -81,8 +80,6 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
8180
def float8_mm(aten_op, args, kwargs=None):
8281
a = args[0]
8382
b = args[1]
84-
if isinstance(b, AsyncCollectiveTensor):
85-
b = b.trigger_wait()
8683

8784
assert isinstance(a, Float8Tensor) and isinstance(
8885
b, Float8Tensor

float8_experimental/float8_tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
264264
# Lazy import to avoid circular dependency
265265
from float8_experimental.float8_ops import FLOAT8_OPS_TABLE
266266

267+
# All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs
268+
# And don't support mixed tensor subclasses. This will trigger the handler for
269+
# the next type in the dispatch list. torch._C._TensorMeta is for FakeTensor
270+
def allowed_subclasses(type):
271+
return issubclass(cls, type) or isinstance(type, torch._C._TensorMeta)
272+
273+
if not all(allowed_subclasses(t) for t in types):
274+
return NotImplemented
275+
267276
if func in FLOAT8_OPS_TABLE:
268277
return FLOAT8_OPS_TABLE[func](func, args, kwargs)
269278
raise NotImplementedError(f"attempting to run {func}, this is not supported")

0 commit comments

Comments
 (0)