
Commit cbffae1

Allow einsum to work with inputs of unknown static shape
1 parent 43188a8 commit cbffae1

2 files changed (+102 additions, -20 deletions)
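Before the diffs, a minimal usage sketch (not part of the commit; adapted from the new test below) of what this change enables: building and evaluating an einsum whose inputs have no known static shape.

import numpy as np
import pytensor
import pytensor.tensor as pt

# All dims unknown at graph-construction time (shape=(None, None))
x = pt.tensor("x", shape=(None, None), dtype="float64")
y = pt.tensor("y", shape=(None, None), dtype="float64")

out = pt.einsum("ij,jk->ik", x, y)
fn = pytensor.function([x, y], out)

rng = np.random.default_rng(37)
a, b = rng.normal(size=(2, 3)), rng.normal(size=(3, 5))
np.testing.assert_allclose(fn(a, b), np.einsum("ij,jk->ik", a, b))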

pytensor/tensor/einsum.py

Lines changed: 91 additions & 16 deletions
@@ -1,4 +1,5 @@
 import collections
+import itertools
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise
@@ -9,6 +10,8 @@
     normalize_axis_index,
     normalize_axis_tuple,
 )
+from opt_einsum.helpers import find_contraction
+from opt_einsum.parser import parse_einsum_input

 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor import TensorLike, vectorize
@@ -32,14 +35,13 @@ class Einsum(OpFromGraph):
     Wrapper Op for Einsum graphs
     """

-    __props__ = ("subscripts", "optimize")
+    __props__ = ("subscripts", "path", "optimized")

-    def __init__(
-        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
-    ):
+    def __init__(self, *args, subscripts: str, path: str, optimized: bool, **kwargs):
         self.subscripts = subscripts
-        self.optimize = optimize
-        super().__init__(*args, **kwargs)
+        self.path = path
+        self.optimized = optimized
+        super().__init__(*args, **kwargs, strict=True)


 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
@@ -142,6 +144,57 @@ def _general_dot(
     return cast(TensorVariable, out)


+PATH = tuple[tuple[int] | tuple[int, int]]
+
+
+def contraction_list_from_path(
+    subscripts: str, operands: Sequence[TensorLike], path: PATH
+):
+    """TODO Docstrings
+
+    Code adapted from einsum_opt
+    """
+    fake_operands = [
+        np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
+    ]
+    input_subscripts, output_subscript, operands = parse_einsum_input(
+        (subscripts, *fake_operands)
+    )
+
+    # Build a few useful list and sets
+    input_list = input_subscripts.split(",")
+    input_sets = [set(x) for x in input_list]
+    output_set = set(output_subscript)
+
+    # Build contraction tuple (positions, gemm, einsum_str, remaining)
+    contraction_list = []
+    for cnum, contract_inds in enumerate(path):
+        # Make sure we remove inds from right to left
+        contract_inds = tuple(sorted(contract_inds, reverse=True))
+
+        contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+        out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+        tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+        # Last contraction
+        if (cnum - len(path)) == -1:
+            idx_result = output_subscript
+        else:
+            # use tensordot order to minimize transpositions
+            all_input_inds = "".join(tmp_inputs)
+            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
+
+        input_list.append(idx_result)
+        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+        # We only need the first three inputs to build the forward graph
+        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction_list.append(contraction)
+
+    return contraction_list
+
+
 def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
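A hypothetical illustration (not part of the commit) of the contraction list this new helper produces for a default left-to-right path, assuming it stays importable from pytensor.tensor.einsum:

import pytensor.tensor as pt
from pytensor.tensor.einsum import contraction_list_from_path

x = pt.tensor("x", shape=(None, None))  # "ij"
y = pt.tensor("y", shape=(None, None))  # "jk"
z = pt.tensor("z", shape=(None, None))  # "kl"

# Default path used when static shapes are unknown (see the einsum() change below)
path = [(1, 0), (1, 0)]
for contract_inds, idx_removed, einsum_str, *_ in contraction_list_from_path(
    "ij,jk,kl->il", [x, y, z], path
):
    print(contract_inds, idx_removed, einsum_str)
# Roughly: (1, 0) {'j'} jk,ij->ki   then   (1, 0) {'k'} ki,kl->il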
@@ -168,18 +221,33 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     # TODO: Do we need this as dependency?
     from opt_einsum import contract_path

-    operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+    operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in operands]

-    # TODE: Do fast path at creation time, and optimize only in fast_run
-    _, contraction_list = contract_path(
-        subscripts,
-        *shapes,
-        einsum_call=True,
-        use_blas=True,
-        optimize="optimal",
-        shapes=True,
-    )
+    if None in itertools.chain.from_iterable(shapes):
+        # We mark optimized = False, even in cases where there is no ordering optimization to be done
+        # because the inner graph may have to accommodate dynamic shapes.
+        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        if len(operands) == 1:
+            path = [(0,)]
+        else:
+            # Create default path of repeating (1,0) that executes left to right cyclically
+            # with intermediate outputs being pushed to the end of the stack
+            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
+            path = [(1, 0) for i in range(len(operands) - 1)]
+        contraction_list = contraction_list_from_path(subscripts, operands, path)
+        optimized = False
+    else:
+        _, contraction_list = contract_path(
+            subscripts,
+            *shapes,
+            einsum_call=True,
+            use_blas=True,
+            optimize="optimal",
+            shapes=True,
+        )
+        path = [contraction[0] for contraction in contraction_list]
+        optimized = True

     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
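The comments in this hunk describe the fallback path of repeated (1, 0) steps; a tiny pure-Python sketch (not from the commit) of how that path walks the operand stack, with each intermediate pushed to the end, opt_einsum-style:

operands = ["A", "B", "C", "D"]
path = [(1, 0)] * (len(operands) - 1)
for step in path:
    second, first = (operands.pop(i) for i in sorted(step, reverse=True))
    operands.append(f"({first}*{second})")  # intermediate goes to the end of the stack
print(operands)  # ['((A*B)*(C*D))'], i.e. left to right, cyclically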
@@ -246,6 +314,7 @@ def sum_repeats(
         lhs, rhs = map(einsum_operands.pop, operand_indices)
         lhs_names, rhs_names = input_names

+        # TODO: Do this as well?
         # handle cases where one side of a contracting or batch dimension is 1
         # but its counterpart is not.
         # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
@@ -323,6 +392,10 @@ def sum_repeats(
                 axes=(lhs_cont, rhs_cont),
                 batch_axes=(lhs_batch, rhs_batch),
             )
+        else:
+            raise ValueError(
+                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+            )

         # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
         assert len(names) == len(result_names) == len(set(names))
@@ -338,5 +411,7 @@ def sum_repeats(
         subscripts=subscripts,
         inputs=list(operands),
         outputs=[einsum_result],
+        path=tuple(path),
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)
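A rough sketch (not part of the commit) of the two props now recorded on the resulting Op, mirroring the assertion added in the test below:

import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, None))
y = pt.tensor("y", shape=(None, None))
z = pt.tensor("z", shape=(None, None))

out = pt.einsum("ij,jk,kl->il", x, y, z)
print(out.owner.op.optimized)  # False: shapes unknown, fallback path was used
print(out.owner.op.path)       # ((1, 0), (1, 0))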

tests/tensor/test_einsum.py

Lines changed: 11 additions & 4 deletions
@@ -72,6 +72,7 @@ def test_general_dot():
     )


+@pytest.mark.parametrize("static_shape_known", [True, False])
 @pytest.mark.parametrize(
     "signature",
     [
@@ -95,16 +96,23 @@ def test_general_dot():
         "oij,imj,mjkn,lnk,plk->op",
     ],
 )
-def test_parse_einsum_input(signature):
+def test_einsum_signatures(static_shape_known, signature):
     letters_to_dims = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True))

     inputs = signature.split("->")[0].split(",")

     shapes = [tuple(letters_to_dims[letter] for letter in inp) for inp in inputs]
+    if static_shape_known:
+        static_shapes = shapes
+    else:
+        static_shapes = [[None] * len(shape) for shape in shapes]
+
     operands = [
-        pt.tensor(name, shape=shape) for name, shape in zip(ascii_lowercase, shapes)
+        pt.tensor(name, shape=static_shape)
+        for name, static_shape in zip(ascii_lowercase, static_shapes)
     ]
     out = pt.einsum(signature, *operands)
+    assert out.owner.op.optimized == static_shape_known

     rng = np.random.default_rng(37)
     test_values = [rng.normal(size=shape) for shape in shapes]
@@ -113,9 +121,8 @@ def test_parse_einsum_input(signature):
     fn = function(operands, out)
     pt_out = fn(*test_values)

-    # print()
     # import pytensor
-    # pytensor.dprint(fn, print_type=True)
+    # print(); pytensor.dprint(fn, print_type=True)

     # assert out.type.shape == np_out.shape # Reshape operations lose static shape
     np.testing.assert_allclose(pt_out, np_out)
