
Commit 99e7eab

Revert optimized->optimize and clarify
1 parent 36c8e1a commit 99e7eab

3 files changed: +21 -9 lines changed


pytensor/link/jax/dispatch/einsum.py

Lines changed: 1 addition & 2 deletions
@@ -7,9 +7,8 @@
 @jax_funcify.register(Einsum)
 def jax_funcify_Einsum(op, **kwargs):
     subscripts = op.subscripts
-    optimize = op.optimize

     def einsum(*operands):
-        return jnp.einsum(subscripts, *operands, optimize=optimize)
+        return jnp.einsum(subscripts, *operands, optimize="optimal")

     return einsum
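
For orientation, a minimal sketch of the behavior the new dispatch relies on (the operands and shapes below are invented for illustration and are not part of the commit): jnp.einsum accepts an optimize argument, and "optimal" asks it to search for the cheapest contraction order itself instead of reusing the flag that used to be stored on the Op.

import jax.numpy as jnp
import numpy as np

# Made-up operands purely for illustration.
a = np.ones((2, 3))
b = np.ones((3, 4))
c = np.ones((4, 5))

# With the change above, the JAX backend always requests an "optimal"
# contraction plan from jnp.einsum rather than forwarding op.optimize.
out = jnp.einsum("ij,jk,kl->il", a, b, c, optimize="optimal")
print(out.shape)  # (2, 5)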

pytensor/tensor/einsum.py

Lines changed: 19 additions & 6 deletions
@@ -33,14 +33,27 @@
 class Einsum(OpFromGraph):
     """
     Wrapper Op for Einsum graphs
+
+    Notes
+    -----
+    The `optimized` prop indicates whether the inner graph was optimized, which can only be done when all shapes are
+    statically known. This is now determined at graph creation time only. We could introduce a rewrite that tries to
+    optimize the graph if static shapes become known later (e.g., after use of `clone_replace` or shape inference during
+    rewrites).
+
+    Also, once the graph is optimized, it could be inlined for potential further optimization that considers the rest of
+    the graph.
+
+    This prop is different from the `optimize` kwarg in numpy that determines what kind (if any) of optimization is
+    desired. We haven't decided whether we want to provide this functionality.
     """

-    __props__ = ("subscripts", "path", "optimize")
+    __props__ = ("subscripts", "path", "optimized")

-    def __init__(self, *args, subscripts: str, path: str, optimize: bool, **kwargs):
+    def __init__(self, *args, subscripts: str, path: str, optimized: bool, **kwargs):
         self.subscripts = subscripts
         self.path = path
-        self.optimize = optimize
+        self.optimized = optimized
         super().__init__(*args, **kwargs, strict=True)

@@ -368,7 +381,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
         contraction_list = _contraction_list_from_path(subscripts, operands, path)

         # If there are only 1 or 2 operands, there is no optimization to be done?
-        optimize = len(operands) <= 2
+        optimized = len(operands) <= 2
     else:
         # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
         # contraction order.
@@ -382,7 +395,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
             optimize="optimal",
         )
         path = [contraction[0] for contraction in contraction_list]
-        optimize = True
+        optimized = True

     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
@@ -550,6 +563,6 @@ def sum_repeats(
         inputs=list(operands),
         outputs=[einsum_result],
         path=tuple(path),
-        optimize=optimize,
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)
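
As the new Notes explain, whether the inner graph is optimized is decided once, at graph creation time, from the operands' static shapes. A minimal sketch of the observable effect (assuming variables built with pt.tensor, with static and unknown shapes chosen purely for illustration; not part of the commit):

import pytensor.tensor as pt

# All shapes statically known: opt_einsum can compute an optimal contraction path.
x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(3, 4))
z = pt.tensor("z", shape=(4, 5))
out = pt.einsum("ij,jk,kl->il", x, y, z)
print(out.owner.op.optimized)  # True

# One unknown dimension and more than two operands: no optimal path can be
# computed up front, so the Op records optimized=False.
w = pt.tensor("w", shape=(None, 3))
out2 = pt.einsum("ij,jk,kl->il", w, y, z)
print(out2.owner.op.optimized)  # False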

tests/tensor/test_einsum.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ def test_einsum_signatures(static_shape_known, signature):
         for name, static_shape in zip(ascii_lowercase, static_shapes)
     ]
     out = pt.einsum(signature, *operands)
-    assert out.owner.op.optimize == static_shape_known or len(operands) <= 2
+    assert out.owner.op.optimized == static_shape_known or len(operands) <= 2

     rng = np.random.default_rng(37)
     test_values = [rng.normal(size=shape).astype(floatX) for shape in shapes]
