Skip to content

Commit 0ba2fa4

Browse files
jessegrabowski and zaxtax
authored and committed
Implement Einsum as OpFromGraph
Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Rob Zinkov <[email protected]>
1 parent afc1a6c commit 0ba2fa4

File tree

10 files changed

+597
-8
lines changed

10 files changed

+597
-8
lines changed

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pytensor.link.jax.dispatch.scan
1515
import pytensor.link.jax.dispatch.sparse
1616
import pytensor.link.jax.dispatch.blockwise
17+
import pytensor.link.jax.dispatch.einsum
1718
import pytensor.link.jax.dispatch.sort
1819

1920
# isort: on

pytensor/link/jax/dispatch/einsum.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import jax.numpy as jnp
2+
3+
from pytensor.link.jax.dispatch import jax_funcify
4+
from pytensor.tensor.einsum import Einsum
5+
6+
7+
@jax_funcify.register(Einsum)
def jax_funcify_Einsum(op, **kwargs):
    """Return a JAX callable that evaluates the given ``Einsum`` op.

    The op's static configuration (subscripts string and optimize flag) is
    read once here, so the returned closure only forwards the runtime
    operands to ``jnp.einsum``.
    """
    spec = op.subscripts
    opt = op.optimize

    def einsum(*operands):
        return jnp.einsum(spec, *operands, optimize=opt)

    return einsum

pytensor/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
153153
from pytensor.tensor.functional import vectorize
154154
# isort: on
155155

156+
from pytensor.tensor.einsum import einsum
157+
156158

157159
__all__ = ["random"] # noqa: F405

pytensor/tensor/basic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,7 +1997,12 @@ def transpose(x, axes=None):
19971997
_x = as_tensor_variable(x)
19981998

19991999
if axes is None:
2000-
axes = list(range((_x.type.ndim - 1), -1, -1))
2000+
axes = tuple(range((_x.type.ndim - 1), -1, -1))
2001+
2002+
if tuple(axes) == tuple(range(len(axes))):
2003+
# No-op
2004+
return _x
2005+
20012006
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
20022007

20032008
if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
@@ -3976,6 +3981,10 @@ def moveaxis(
39763981
source = normalize_axis_tuple(source, a.ndim, "source")
39773982
destination = normalize_axis_tuple(destination, a.ndim, "destination")
39783983

3984+
if source == destination:
3985+
# It's a no-op
3986+
return a
3987+
39793988
if len(source) != len(destination):
39803989
raise ValueError(
39813990
"`source` and `destination` arguments must have the same number of elements"
@@ -4290,9 +4299,7 @@ def atleast_Nd(
42904299
atleast_3d = partial(atleast_Nd, n=3)
42914300

42924301

4293-
def expand_dims(
4294-
a: np.ndarray | TensorVariable, axis: tuple[int, ...]
4295-
) -> TensorVariable:
4302+
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
42964303
"""Expand the shape of an array.
42974304
42984305
Insert a new axis that will appear at the `axis` position in the expanded
@@ -4311,7 +4318,7 @@ def expand_dims(
43114318
"""
43124319
a = as_tensor(a)
43134320

4314-
if not isinstance(axis, tuple | list):
4321+
if not isinstance(axis, Sequence):
43154322
axis = (axis,)
43164323

43174324
out_ndim = len(axis) + a.ndim

0 commit comments

Comments
 (0)