Commit 29e64e9

Handle subscripts shared with broadcasted dimensions
1 parent fe20a66 commit 29e64e9

File tree

2 files changed: +58 -14 lines changed
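For context, a minimal sketch of the pattern the commit message refers to, using the shapes from the new test below (`pt` stands for `pytensor.tensor`; the snippet is illustrative, not part of the commit):

import pytensor.tensor as pt

a = pt.tensor("a", shape=(32, 32, 32))
b = pt.tensor("b", shape=(1000, 32))
c = pt.tensor("c", shape=(1, 32))  # leading dimension is broadcastable

# The subscript "b" labels both the length-1000 axis of `b` and the
# broadcastable length-1 axis of `c`. After this commit, pytensor squeezes
# the broadcastable axis before contracting and warns that the shared
# subscript can result in a suboptimal contraction path.
out = pt.einsum("ijk,bj,bk->i", a, b, c)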

pytensor/tensor/einsum.py

Lines changed: 19 additions & 14 deletions

@@ -1,5 +1,6 @@
 import collections
 import itertools
+import warnings
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise
@@ -385,7 +386,6 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     else:
         # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
         # contraction order.
-        # Call _implementation to bypass dispatch
         _, contraction_list = np.einsum_path(
             subscripts,
             # Numpy einsum_path requires arrays even though only the shapes matter
@@ -428,14 +428,22 @@ def sum_repeats(
                 names = names.replace(name, "", count - 1)
         return operand, names
 
-    # def filter_singleton_dims(operand, names, other_shape, other_names):
-    #     eq = core.definitely_equal
-    #     keep = [
-    #         not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
-    #         for i, j in enumerate(map(other_names.find, names))
-    #     ]
-    #     sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))
-    #     return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes)
+    def filter_singleton_dims(operand, names, other_operand, other_names):
+        op_bcast = operand.type.broadcastable
+        other_bcast = other_operand.type.broadcastable
+        keep = [
+            (not op_bcast[i]) or (j == -1) or other_bcast[j]
+            for i, j in enumerate(map(other_names.find, names))
+        ]
+        keep_axes = [i for i, keep_axis in enumerate(keep) if keep_axis]
+        squeeze_axes = [i for i, keep_axis in enumerate(keep) if not keep_axis]
+        if squeeze_axes:
+            # TODO: We could modify the subscripts to avoid the problem?
+            warnings.warn(
+                "The same einsum subscript is used for a broadcastable and non-broadcastable dimension. "
+                "This can result in a suboptimal contraction path."
+            )
+        return operand.squeeze(squeeze_axes), "".join(names[i] for i in keep_axes)
 
     einsum_operands = list(operands)  # So we can pop
     for operand_indices, contracted_names, einstr, _, _ in contraction_list:
@@ -465,13 +473,10 @@ def sum_repeats(
         lhs, rhs = map(einsum_operands.pop, operand_indices)
         lhs_names, rhs_names = input_names
 
-        # TODO: Do this as well?
         # handle cases where one side of a contracting or batch dimension is 1
         # but its counterpart is not.
-        # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
-        #                                        rhs_names)
-        # rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
-        #                                        lhs_names)
+        lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, rhs, rhs_names)
+        rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, lhs, lhs_names)
 
         lhs_counts = collections.Counter(lhs_names)
         rhs_counts = collections.Counter(rhs_names)
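To make the new `filter_singleton_dims` logic concrete, here is a hypothetical stand-alone rendering of its keep/squeeze computation over plain broadcastable tuples (the function name and calling convention are invented for illustration; the list comprehensions mirror the diff above):

def keep_and_squeeze_axes(names, op_bcast, other_names, other_bcast):
    # An axis is kept unless it is broadcastable here while the axis with the
    # same subscript letter in the other operand is not broadcastable.
    keep = [
        (not op_bcast[i]) or (j == -1) or other_bcast[j]
        for i, j in enumerate(map(other_names.find, names))
    ]
    keep_axes = [i for i, k in enumerate(keep) if k]
    squeeze_axes = [i for i, k in enumerate(keep) if not k]
    return keep_axes, squeeze_axes

# Operand c from the test below has names "bk" and shape (1, 32); its
# counterpart b has names "bj" and shape (1000, 32):
print(keep_and_squeeze_axes("bk", (True, False), "bj", (False, False)))
# ([1], [0]) -- the broadcastable "b" axis of c is squeezed before contracting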

tests/tensor/test_einsum.py

Lines changed: 39 additions & 0 deletions

@@ -14,6 +14,9 @@
 from pytensor.tensor.shape import Reshape
 
 
+# Fail for unexpected warnings in this file
+pytestmark = pytest.mark.filterwarnings("error")
+
 floatX = pytensor.config.floatX
 ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4
 
@@ -214,3 +217,39 @@ def test_ellipsis():
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
     )
+
+
+def test_broadcastable_dims():
+    # Test that einsum handles broadcasting dims correctly. There are two points:
+    # 1. Numpy einsum allows the same subscript for degenerate and full dimensions
+    # There is some stale discussion on whether this should be a bug or not, but for now it is not:
+    # https://github.com/numpy/numpy/issues/11548
+
+    # 2. Using the same letter for dimensions that are and aren't broadcastable
+    # can lead to suboptimal paths. We check we issue a warning for the following example:
+    # https://github.com/dgasmith/opt_einsum/issues/220
+    rng = np.random.default_rng(222)
+    a = pt.tensor("a", shape=(32, 32, 32))
+    b = pt.tensor("b", shape=(1000, 32))
+    c = pt.tensor("c", shape=(1, 32))
+
+    a_test = rng.normal(size=a.type.shape).astype(floatX)
+    b_test = rng.normal(size=b.type.shape).astype(floatX)
+    c_test = rng.normal(size=c.type.shape).astype(floatX)
+
+    # Note b is used for both 1 and 32
+    with pytest.warns(
+        UserWarning, match="This can result in a suboptimal contraction path"
+    ):
+        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
+    assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]
+
+    # If we use a distinct letter we get the optimal path
+    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
+    assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]
+
+    suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
+    optimal_eval = optimal_out.eval({a: a_test, b: b_test, c: c_test})
+    np_eval = np.einsum("ijk,bj,bk->i", a_test, b_test, c_test)
+    np.testing.assert_allclose(suboptimal_eval, np_eval)
+    np.testing.assert_allclose(optimal_eval, np_eval)
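Point 1 in the test's comments can be checked directly against numpy; a standalone demonstration, not part of the commit:

import numpy as np

x = np.ones((3, 4))
y = np.ones((3, 1))  # degenerate axis shares the subscript "j" with x
# numpy broadcasts the length-1 axis instead of raising (see numpy issue 11548)
print(np.einsum("ij,ij->i", x, y))  # [4. 4. 4.]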
