Commit 3fe3257

Allow einsum to work with inputs of unknown static shape
1 parent 967c7d7 commit 3fe3257

File tree

2 files changed: +92 additions, -18 deletions


pytensor/tensor/einsum.py

Lines changed: 82 additions & 14 deletions
@@ -1,4 +1,5 @@
 import collections
+import itertools
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise
@@ -9,6 +10,8 @@
     normalize_axis_index,
     normalize_axis_tuple,
 )
+from opt_einsum.helpers import find_contraction
+from opt_einsum.parser import parse_einsum_input

 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor import TensorLike, vectorize
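
The two opt_einsum helpers imported here are the building blocks of the fallback added further down. As a rough standalone illustration (not part of this commit): parse_einsum_input normalizes a subscript string and infers an implicit output, while find_contraction reports which indices survive a contraction step and which are summed away.

import numpy as np
from opt_einsum.helpers import find_contraction
from opt_einsum.parser import parse_einsum_input

# parse_einsum_input takes (subscripts, *operands) and returns the normalized
# input subscripts, the (possibly implicit) output subscript, and the operands.
ops = [np.zeros((2, 3)), np.zeros((3, 4))]
in_subs, out_sub, parsed_ops = parse_einsum_input(("ij,jk", *ops))
print(in_subs, out_sub)  # ij,jk ik

# find_contraction describes a single step: contracting operands 0 and 1 keeps
# {i, k} (still needed downstream) and sums away {j}.
new_result, remaining, idx_removed, idx_contract = find_contraction(
    (0, 1), [frozenset("ij"), frozenset("jk")], frozenset("ik")
)
print(sorted(new_result), sorted(idx_removed))  # ['i', 'k'] ['j']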
@@ -32,14 +35,15 @@ class Einsum(OpFromGraph):
     Wrapper Op for Einsum graphs
     """

-    __props__ = ("subscripts", "optimize")
+    __props__ = ("subscripts", "path", "optimized")

     def __init__(
-        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
+        self, *args, subscripts: str, path: str, optimized: bool, **kwargs
     ):
         self.subscripts = subscripts
-        self.optimize = optimize
-        super().__init__(*args, **kwargs)
+        self.path = path
+        self.optimized = optimized
+        super().__init__(*args, **kwargs, strict=True)


 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
@@ -142,6 +146,50 @@ def _general_dot(
     return cast(TensorVariable, out)


+PATH = tuple[tuple[int] | tuple[int, int]]
+
+def contraction_list_from_path(subscripts: str, operands: Sequence[TensorLike], path: PATH):
+    """TODO Docstrings
+
+    Code adapted from einsum_opt
+    """
+    fake_operands = [np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands]
+    input_subscripts, output_subscript, operands = parse_einsum_input((subscripts, *fake_operands))
+
+    # Build a few useful list and sets
+    input_list = input_subscripts.split(',')
+    input_sets = [set(x) for x in input_list]
+    output_set = set(output_subscript)
+
+    # Build contraction tuple (positions, gemm, einsum_str, remaining)
+    contraction_list = []
+    for cnum, contract_inds in enumerate(path):
+        # Make sure we remove inds from right to left
+        contract_inds = tuple(sorted(list(contract_inds), reverse=True))
+
+        contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+        out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+        tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+        # Last contraction
+        if (cnum - len(path)) == -1:
+            idx_result = output_subscript
+        else:
+            # use tensordot order to minimize transpositions
+            all_input_inds = "".join(tmp_inputs)
+            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
+
+        input_list.append(idx_result)
+        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+        # We only need the first three inputs to build the forward graph
+        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction_list.append(contraction)
+
+    return contraction_list
+
+
 def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
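
A sketch of how the new helper behaves on its own, assuming the module layout shown in this diff (the import path and the snippet are illustrative, not part of the commit). With the default left-to-right path, each returned entry is (operand positions, indices summed away, einsum string for that step, None, None), which is all the graph construction below consumes.

import pytensor.tensor as pt
from pytensor.tensor.einsum import contraction_list_from_path

# Three matrices whose static shapes are unknown
x, y, z = (pt.tensor(name, shape=(None, None)) for name in "xyz")

# Default left-to-right path: contract the first two operands, then the
# intermediate result with the remaining one
path = [(1, 0), (1, 0)]
for contract_inds, idx_removed, einsum_str, _, _ in contraction_list_from_path(
    "ij,jk,kl->il", [x, y, z], path
):
    print(contract_inds, idx_removed, einsum_str)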
@@ -168,18 +216,33 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     # TODO: Do we need this as dependency?
     from opt_einsum import contract_path

-    operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+    operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in operands]

-    # TODE: Do fast path at creation time, and optimize only in fast_run
-    _, contraction_list = contract_path(
-        subscripts,
-        *shapes,
-        einsum_call=True,
-        use_blas=True,
-        optimize="optimal",
-        shapes=True,
-    )
+    if None in itertools.chain.from_iterable(shapes):
+        # We mark optimized = False, even in cases where there is no ordering optimization to be done
+        # because the inner graph may have to accommodate dynamic shapes.
+        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        if len(operands) == 1:
+            path = [(0,)]
+        else:
+            # Create default path of repeating (1,0) that executes left to right cyclically
+            # with intermediate outputs being pushed to the end of the stack
+            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
+            path = [(1,0) for i in range(len(operands) - 1)]
+        contraction_list = contraction_list_from_path(subscripts, operands, path)
+        optimized = False
+    else:
+        _, contraction_list = contract_path(
+            subscripts,
+            *shapes,
+            einsum_call=True,
+            use_blas=True,
+            optimize="optimal",
+            shapes=True,
+        )
+        path = [contraction[0] for contraction in contraction_list]
+        optimized = True

     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
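
The practical effect of the branch above, as an illustrative snippet (not part of the commit): when any static dimension is unknown, the Op records the naive left-to-right path with optimized=False; when all shapes are known, opt_einsum's "optimal" path is used and optimized=True.

import pytensor.tensor as pt

# Unknown static shapes -> default left-to-right path, optimized is False
a, b, c = (pt.tensor(name, shape=(None, None)) for name in "abc")
out = pt.einsum("ij,jk,kl->il", a, b, c)
print(out.owner.op.path)       # ((1, 0), (1, 0))
print(out.owner.op.optimized)  # False

# Fully known static shapes -> opt_einsum's "optimal" path, optimized is True
a2 = pt.tensor("a2", shape=(2, 3))
b2 = pt.tensor("b2", shape=(3, 5))
c2 = pt.tensor("c2", shape=(5, 7))
print(pt.einsum("ij,jk,kl->il", a2, b2, c2).owner.op.optimized)  # True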
@@ -246,6 +309,7 @@ def sum_repeats(
             lhs, rhs = map(einsum_operands.pop, operand_indices)
             lhs_names, rhs_names = input_names

+            # TODO: Do this as well?
             # handle cases where one side of a contracting or batch dimension is 1
             # but its counterpart is not.
             # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
@@ -323,6 +387,8 @@ def sum_repeats(
                 axes=(lhs_cont, rhs_cont),
                 batch_axes=(lhs_batch, rhs_batch),
             )
+        else:
+            raise ValueError(f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}")

         # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
         assert len(names) == len(result_names) == len(set(names))
@@ -338,5 +404,7 @@ def sum_repeats(
         subscripts=subscripts,
         inputs=list(operands),
         outputs=[einsum_result],
+        path=tuple(path),
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)
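
Taken together, the change means an einsum graph can now be built and compiled when shapes are only known at run time. A minimal end-to-end sketch, mirroring the updated test but not itself part of the commit:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, None))
y = pt.tensor("y", shape=(None, None))
out = pt.einsum("ij,jk->ik", x, y)  # static shapes are unknown at graph-construction time

fn = pytensor.function([x, y], out)
rng = np.random.default_rng(0)
a, b = rng.normal(size=(4, 3)), rng.normal(size=(3, 5))
np.testing.assert_allclose(fn(a, b), np.einsum("ij,jk->ik", a, b))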

tests/tensor/test_einsum.py

Lines changed: 10 additions & 4 deletions
@@ -72,6 +72,7 @@ def test_general_dot():
     )


+@pytest.mark.parametrize("static_shape_known", [True, False])
 @pytest.mark.parametrize(
     "signature",
     [
@@ -95,16 +96,22 @@ def test_general_dot():
         "oij,imj,mjkn,lnk,plk->op",
     ],
 )
-def test_parse_einsum_input(signature):
+def test_enisum_signatures(static_shape_known, signature):
     letters_to_dims = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True))

     inputs = signature.split("->")[0].split(",")

     shapes = [tuple(letters_to_dims[letter] for letter in inp) for inp in inputs]
+    if static_shape_known:
+        static_shapes = shapes
+    else:
+        static_shapes = [[None] * len(shape) for shape in shapes]
+
     operands = [
-        pt.tensor(name, shape=shape) for name, shape in zip(ascii_lowercase, shapes)
+        pt.tensor(name, shape=static_shape) for name, static_shape in zip(ascii_lowercase, static_shapes)
     ]
     out = pt.einsum(signature, *operands)
+    assert out.owner.op.optimized == static_shape_known

     rng = np.random.default_rng(37)
     test_values = [rng.normal(size=shape) for shape in shapes]
113120
fn = function(operands, out)
114121
pt_out = fn(*test_values)
115122

116-
# print()
117123
# import pytensor
118-
# pytensor.dprint(fn, print_type=True)
124+
# print(); pytensor.dprint(fn, print_type=True)
119125

120126
# assert out.type.shape == np_out.shape # Reshape operations lose static shape
121127
np.testing.assert_allclose(pt_out, np_out)
