|
11 | 11 |
|
12 | 12 | from __future__ import annotations
|
13 | 13 |
|
| 14 | +import functools |
14 | 15 | import warnings
|
15 | 16 | from copy import deepcopy
|
16 | 17 | from typing import Any, Sequence
|
|
29 | 30 | __all__ = ["MetaTensor"]
|
30 | 31 |
|
31 | 32 |
|
| 33 | +@functools.lru_cache(None) |
| 34 | +def _get_named_tuple_like_type(func): |
| 35 | + if ( |
| 36 | + hasattr(torch, "return_types") |
| 37 | + and hasattr(func, "__name__") |
| 38 | + and hasattr(torch.return_types, func.__name__) |
| 39 | + and isinstance(getattr(torch.return_types, func.__name__), type) |
| 40 | + ): |
| 41 | + return getattr(torch.return_types, func.__name__) |
| 42 | + return None |
| 43 | + |
| 44 | + |
| 45 | +def _not_requiring_metadata(ret): |
| 46 | + return isinstance(ret, (int, str, bytes, torch.Size, torch.dtype, torch.device, np.ndarray)) or not ( |
| 47 | + isinstance(ret, MetaTensor) or (isinstance(ret, Sequence) and any(isinstance(x, MetaTensor) for x in ret)) |
| 48 | + ) |
| 49 | + |
| 50 | + |
32 | 51 | class MetaTensor(MetaObj, torch.Tensor):
|
33 | 52 | """
|
34 | 53 | Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.
|
@@ -253,20 +272,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
|
253 | 272 | # we might have 1 or multiple outputs. Might be MetaTensor, might be something
|
254 | 273 | # else (e.g., `__repr__` returns a string).
|
255 | 274 | # Convert to list (if necessary), process, and at end remove list if one was added.
|
256 |
| - if ( |
257 |
| - hasattr(torch, "return_types") |
258 |
| - and hasattr(func, "__name__") |
259 |
| - and hasattr(torch.return_types, func.__name__) |
260 |
| - and isinstance(getattr(torch.return_types, func.__name__), type) |
261 |
| - and isinstance(ret, getattr(torch.return_types, func.__name__)) |
262 |
| - ): |
| 275 | + if _not_requiring_metadata(ret): |
| 276 | + return ret |
| 277 | + if _get_named_tuple_like_type(func) is not None and isinstance(ret, _get_named_tuple_like_type(func)): |
263 | 278 | # for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like
|
264 | 279 | out_items = MetaTensor.update_meta(ret, func, args, kwargs)
|
265 | 280 | for idx in range(ret.n_fields):
|
266 | 281 | ret[idx].meta = out_items[idx].meta
|
267 | 282 | ret[idx].applied_operations = out_items[idx].applied_operations
|
268 | 283 | return ret
|
269 |
| - if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence): |
| 284 | + if not isinstance(ret, Sequence): |
270 | 285 | ret = [ret]
|
271 | 286 | unpack = True
|
272 | 287 | else:
|
|
0 commit comments