Skip to content

Commit af440de

Browse files
Add HasShape mixin
1 parent 7705b3e commit af440de

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

aesara/graph/type.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import abstractmethod
2-
from typing import Any, Optional, Text, TypeVar, Union
2+
from typing import Any, Optional, Text, Tuple, TypeVar, Union
33

44
from typing_extensions import TypeAlias
55

@@ -257,6 +257,13 @@ def values_eq_approx(cls, a: Any, b: Any):
257257

258258

259259
class HasDataType:
260-
"""A mixing for an `Op` type that has a :attr:`dtype` attribute."""
260+
"""A mixin for a type that has a :attr:`dtype` attribute."""
261261

262262
dtype: str
263+
264+
265+
class HasShape:
266+
"""A mixin for a type that has :attr:`shape` and :attr:`ndim` attributes."""
267+
268+
ndim: int
269+
shape: Tuple[Optional[int], ...]

aesara/tensor/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from aesara import scalar as aes
99
from aesara.configdefaults import config
1010
from aesara.graph.basic import Variable
11-
from aesara.graph.type import HasDataType
11+
from aesara.graph.type import HasDataType, HasShape
1212
from aesara.graph.utils import MetaType
1313
from aesara.link.c.type import CType
1414
from aesara.misc.safe_asarray import _asarray
@@ -48,7 +48,7 @@
4848
}
4949

5050

51-
class TensorType(CType, HasDataType):
51+
class TensorType(CType, HasDataType, HasShape):
5252
r"""Symbolic `Type` representing `numpy.ndarray`\s."""
5353

5454
__props__: Tuple[str, ...] = ("dtype", "shape")

0 commit comments

Comments
 (0)