
Commit c76d626

Allow einsum to work with inputs of unknown static shape
1 parent 0ba2fa4 commit c76d626

File tree

2 files changed: 104 additions & 20 deletions

pytensor/tensor/einsum.py

Lines changed: 93 additions & 16 deletions
@@ -1,4 +1,5 @@
 import collections
+import itertools
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise

@@ -9,6 +10,8 @@
     normalize_axis_index,
     normalize_axis_tuple,
 )
+from opt_einsum.helpers import find_contraction
+from opt_einsum.parser import parse_einsum_input

 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor import TensorLike, vectorize

@@ -32,14 +35,13 @@ class Einsum(OpFromGraph):
     Wrapper Op for Einsum graphs
     """

-    __props__ = ("subscripts", "optimize")
+    __props__ = ("subscripts", "path", "optimized")

-    def __init__(
-        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
-    ):
+    def __init__(self, *args, subscripts: str, path: str, optimized: bool, **kwargs):
         self.subscripts = subscripts
-        self.optimize = optimize
-        super().__init__(*args, **kwargs)
+        self.path = path
+        self.optimized = optimized
+        super().__init__(*args, **kwargs, strict=True)


 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:

@@ -140,6 +142,57 @@ def _general_dot(
     return cast(TensorVariable, out)


+PATH = tuple[tuple[int] | tuple[int, int]]
+
+
+def contraction_list_from_path(
+    subscripts: str, operands: Sequence[TensorLike], path: PATH
+):
+    """TODO Docstrings
+
+    Code adapted from einsum_opt
+    """
+    fake_operands = [
+        np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
+    ]
+    input_subscripts, output_subscript, operands = parse_einsum_input(
+        (subscripts, *fake_operands)
+    )
+
+    # Build a few useful list and sets
+    input_list = input_subscripts.split(",")
+    input_sets = [set(x) for x in input_list]
+    output_set = set(output_subscript)
+
+    # Build contraction tuple (positions, gemm, einsum_str, remaining)
+    contraction_list = []
+    for cnum, contract_inds in enumerate(path):
+        # Make sure we remove inds from right to left
+        contract_inds = tuple(sorted(contract_inds, reverse=True))
+
+        contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+        out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+        tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+        # Last contraction
+        if (cnum - len(path)) == -1:
+            idx_result = output_subscript
+        else:
+            # use tensordot order to minimize transpositions
+            all_input_inds = "".join(tmp_inputs)
+            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))

+        input_list.append(idx_result)
+        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+        # We only need the first three inputs to build the forward graph
+        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction_list.append(contraction)
+
+    return contraction_list
+
+
 def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
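The tuples built by contraction_list_from_path mirror the contraction list that opt_einsum itself produces. As a rough sketch (not part of the commit; shapes made up for illustration), the same structure can be inspected through opt_einsum's contract_path, which is what the known-shape branch of einsum below relies on:

    from opt_einsum import contract_path

    # Each entry is (contract_inds, idx_removed, einsum_str, remaining, blas_flag);
    # contraction_list_from_path fills the first three fields for a user-supplied
    # path and leaves the last two as None.
    _, contraction_list = contract_path(
        "ij,jk,kl->il",
        (2, 3), (3, 5), (5, 7),
        einsum_call=True,
        optimize="optimal",
        shapes=True,
    )
    for contract_inds, idx_removed, einsum_str, *_ in contraction_list:
        print(contract_inds, idx_removed, einsum_str)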
@@ -166,18 +219,35 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     # TODO: Do we need this as dependency?
     from opt_einsum import contract_path

-    operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+    operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in operands]

-    # TODE: Do fast path at creation time, and optimize only in fast_run
-    _, contraction_list = contract_path(
-        subscripts,
-        *shapes,
-        einsum_call=True,
-        use_blas=True,
-        optimize="optimal",
-        shapes=True,
-    )
+    if None in itertools.chain.from_iterable(shapes):
+        # We mark optimized = False, even in cases where there is no ordering optimization to be done
+        # because the inner graph may have to accommodate dynamic shapes.
+        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        if len(operands) == 1:
+            path = [(0,)]
+        else:
+            # Create default path of repeating (1,0) that executes left to right cyclically
+            # with intermediate outputs being pushed to the end of the stack
+            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
+            path = [(1, 0) for i in range(len(operands) - 1)]
+        contraction_list = contraction_list_from_path(subscripts, operands, path)
+        optimized = (
+            len(operands) <= 2
+        )  # If there are only 1 or 2 operands, there is no optimization to be done?
+    else:
+        _, contraction_list = contract_path(
+            subscripts,
+            *shapes,
+            einsum_call=True,
+            use_blas=True,
+            optimize="optimal",
+            shapes=True,
+        )
+        path = [contraction[0] for contraction in contraction_list]
+        optimized = True

     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]

@@ -244,6 +314,7 @@ def sum_repeats(
             lhs, rhs = map(einsum_operands.pop, operand_indices)
             lhs_names, rhs_names = input_names

+            # TODO: Do this as well?
             # handle cases where one side of a contracting or batch dimension is 1
             # but its counterpart is not.
             # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),

@@ -321,6 +392,10 @@ def sum_repeats(
                 axes=(lhs_cont, rhs_cont),
                 batch_axes=(lhs_batch, rhs_batch),
             )
+        else:
+            raise ValueError(
+                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+            )

     # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
     assert len(names) == len(result_names) == len(set(names))

@@ -336,5 +411,7 @@ def sum_repeats(
         subscripts=subscripts,
         inputs=list(operands),
         outputs=[einsum_result],
+        path=tuple(path),
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)
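A minimal sketch of the behaviour this commit enables (illustrative, assuming the new path and optimized props shown above): operands with unknown static shapes no longer require a shape-based optimization pass, and instead fall back to the naive left-to-right path.

    import pytensor.tensor as pt

    # Static shapes unknown at graph-construction time
    x = pt.tensor("x", shape=(None, None))
    y = pt.tensor("y", shape=(None, None))
    z = pt.tensor("z", shape=(None, None))

    out = pt.einsum("ij,jk,kl->il", x, y, z)

    # With three operands and no shape information, the Op records the default
    # repeating (1, 0) path and is flagged as not optimized, so it can be
    # rebuilt later if the static shapes become known.
    print(out.owner.op.path)       # ((1, 0), (1, 0))
    print(out.owner.op.optimized)  # False

With fully known static shapes, the else branch calls opt_einsum's contract_path instead, and optimized is set to True.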

tests/tensor/test_einsum.py

Lines changed: 11 additions & 4 deletions
@@ -72,6 +72,7 @@ def test_general_dot():
     )


+@pytest.mark.parametrize("static_shape_known", [True, False])
 @pytest.mark.parametrize(
     "signature",
     [

@@ -95,16 +96,23 @@ def test_general_dot():
         "oij,imj,mjkn,lnk,plk->op",
     ],
 )
-def test_parse_einsum_input(signature):
+def test_einsum_signatures(static_shape_known, signature):
     letters_to_dims = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True))

     inputs = signature.split("->")[0].split(",")

     shapes = [tuple(letters_to_dims[letter] for letter in inp) for inp in inputs]
+    if static_shape_known:
+        static_shapes = shapes
+    else:
+        static_shapes = [[None] * len(shape) for shape in shapes]
+
     operands = [
-        pt.tensor(name, shape=shape) for name, shape in zip(ascii_lowercase, shapes)
+        pt.tensor(name, shape=static_shape)
+        for name, static_shape in zip(ascii_lowercase, static_shapes)
     ]
     out = pt.einsum(signature, *operands)
+    assert out.owner.op.optimized == static_shape_known or len(operands) <= 2

     rng = np.random.default_rng(37)
     test_values = [rng.normal(size=shape) for shape in shapes]

@@ -113,9 +121,8 @@ def test_parse_einsum_input(signature):
     fn = function(operands, out)
     pt_out = fn(*test_values)

-    # print()
     # import pytensor
-    # pytensor.dprint(fn, print_type=True)
+    # print(); pytensor.dprint(fn, print_type=True)

     # assert out.type.shape == np_out.shape # Reshape operations lose static shape
     np.testing.assert_allclose(pt_out, np_out)
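Written out by hand, one parametrization of the test reads roughly as below (a sketch, not part of the commit). Note that the assertion above parses as (out.owner.op.optimized == static_shape_known) or (len(operands) <= 2), since == binds tighter than or: with one or two operands there is no contraction order to optimize, so the optimized flag stays True even when the static shapes are unknown.

    import numpy as np
    import pytensor.tensor as pt
    from pytensor import function

    x = pt.tensor("x", shape=(None, None))
    y = pt.tensor("y", shape=(None, None))
    out = pt.einsum("ij,jk->ik", x, y)
    assert out.owner.op.optimized  # two operands: nothing to reorder

    fn = function([x, y], out)
    rng = np.random.default_rng(37)
    a, b = rng.normal(size=(2, 3)), rng.normal(size=(3, 5))
    np.testing.assert_allclose(fn(a, b), np.einsum("ij,jk->ik", a, b))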
