Skip to content

Commit 624e832

Browse files
wyliNic-Ma
andauthored
5532 tensor.shape in tensor.Size (#5533)
Signed-off-by: Wenqi Li <[email protected]> Fixes #5532 ### Description returns early if it's torch.Size ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <[email protected]> Co-authored-by: Nic Ma <[email protected]>
1 parent bb3ecf6 commit 624e832

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

monai/data/meta_tensor.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import functools
1415
import warnings
1516
from copy import deepcopy
1617
from typing import Any, Sequence
@@ -29,6 +30,24 @@
2930
__all__ = ["MetaTensor"]
3031

3132

33+
@functools.lru_cache(None)
34+
def _get_named_tuple_like_type(func):
35+
if (
36+
hasattr(torch, "return_types")
37+
and hasattr(func, "__name__")
38+
and hasattr(torch.return_types, func.__name__)
39+
and isinstance(getattr(torch.return_types, func.__name__), type)
40+
):
41+
return getattr(torch.return_types, func.__name__)
42+
return None
43+
44+
45+
def _not_requiring_metadata(ret):
46+
return isinstance(ret, (int, str, bytes, torch.Size, torch.dtype, torch.device, np.ndarray)) or not (
47+
isinstance(ret, MetaTensor) or (isinstance(ret, Sequence) and any(isinstance(x, MetaTensor) for x in ret))
48+
)
49+
50+
3251
class MetaTensor(MetaObj, torch.Tensor):
3352
"""
3453
Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.
@@ -253,20 +272,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
253272
# we might have 1 or multiple outputs. Might be MetaTensor, might be something
254273
# else (e.g., `__repr__` returns a string).
255274
# Convert to list (if necessary), process, and at end remove list if one was added.
256-
if (
257-
hasattr(torch, "return_types")
258-
and hasattr(func, "__name__")
259-
and hasattr(torch.return_types, func.__name__)
260-
and isinstance(getattr(torch.return_types, func.__name__), type)
261-
and isinstance(ret, getattr(torch.return_types, func.__name__))
262-
):
275+
if _not_requiring_metadata(ret):
276+
return ret
277+
if _get_named_tuple_like_type(func) is not None and isinstance(ret, _get_named_tuple_like_type(func)):
263278
# for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like
264279
out_items = MetaTensor.update_meta(ret, func, args, kwargs)
265280
for idx in range(ret.n_fields):
266281
ret[idx].meta = out_items[idx].meta
267282
ret[idx].applied_operations = out_items[idx].applied_operations
268283
return ret
269-
if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence):
284+
if not isinstance(ret, Sequence):
270285
ret = [ret]
271286
unpack = True
272287
else:

tests/test_meta_tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,12 @@ def test_str(self):
424424
t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
425425
self.assertEqual(str(t), "tensor([1.])")
426426

427+
def test_shape(self):
428+
s = MetaTensor([1])
429+
self.assertEqual(s.shape, torch.Size([1]))
430+
self.assertEqual(s.size(), torch.Size([1]))
431+
self.assertEqual(s.size(0), 1)
432+
427433
def test_astype(self):
428434
t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
429435
for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.compat.long, np.uint16):

0 commit comments

Comments
 (0)