Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4d7d771

Browse files
committed
fix tests
1 parent ca84eb4 commit 4d7d771

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

float8_experimental/float8_tensor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
266266

267267
# All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs
268268
# And don't support mixed tensor subclasses. This will trigger the handler for
269-
# the next type in the dispatch list. torch._C._TensorMeta is for FakeTensor
269+
# the next type in the dispatch list
270270
def allowed_subclasses(type):
271-
return issubclass(cls, type) or isinstance(type, torch._C._TensorMeta)
271+
return (
272+
issubclass(cls, type)
273+
or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
274+
or issubclass(
275+
torch._subclasses.functional_tensor.FunctionalTensor, type
276+
)
277+
)
272278

273279
if not all(allowed_subclasses(t) for t in types):
274280
return NotImplemented

test/test_everything.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ pytest test/test_compile.py
99
./test/test_fsdp.sh
1010
./test/test_fsdp_compile.sh
1111
./test/test_tp.sh
12+
./test/test_dtensor.sh
1213

1314
echo "all tests successful"

0 commit comments

Comments
 (0)