|
1 | 1 | import collections
|
2 | 2 | import itertools
|
| 3 | +import warnings |
3 | 4 | from collections.abc import Sequence
|
4 | 5 | from functools import partial, reduce
|
5 | 6 | from itertools import pairwise
|
@@ -385,7 +386,6 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
|
385 | 386 | else:
|
386 | 387 | # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
|
387 | 388 | # contraction order.
|
388 |
| - # Call _implementation to bypass dispatch |
389 | 389 | _, contraction_list = np.einsum_path(
|
390 | 390 | subscripts,
|
391 | 391 | # Numpy einsum_path requires arrays even though only the shapes matter
|
@@ -428,14 +428,22 @@ def sum_repeats(
|
428 | 428 | names = names.replace(name, "", count - 1)
|
429 | 429 | return operand, names
|
430 | 430 |
|
431 |
| - # def filter_singleton_dims(operand, names, other_shape, other_names): |
432 |
| - # eq = core.definitely_equal |
433 |
| - # keep = [ |
434 |
| - # not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1) |
435 |
| - # for i, j in enumerate(map(other_names.find, names)) |
436 |
| - # ] |
437 |
| - # sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim))) |
438 |
| - # return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes) |
| 431 | + def filter_singleton_dims(operand, names, other_operand, other_names): |
| 432 | + op_bcast = operand.type.broadcastable |
| 433 | + other_bcast = other_operand.type.broadcastable |
| 434 | + keep = [ |
| 435 | + (not op_bcast[i]) or (j == -1) or other_bcast[j] |
| 436 | + for i, j in enumerate(map(other_names.find, names)) |
| 437 | + ] |
| 438 | + keep_axes = [i for i, keep_axis in enumerate(keep) if keep_axis] |
| 439 | + squeeze_axes = [i for i, keep_axis in enumerate(keep) if not keep_axis] |
| 440 | + if squeeze_axes: |
| 441 | + # TODO: We could modify the subscripts to avoid the problem? |
| 442 | + warnings.warn( |
| 443 | + "The same einsum subscript is used for a broadcastable and non-broadcastable dimension. " |
| 444 | + "This can result in a suboptimal contraction path." |
| 445 | + ) |
| 446 | + return operand.squeeze(squeeze_axes), "".join(names[i] for i in keep_axes) |
439 | 447 |
|
440 | 448 | einsum_operands = list(operands) # So we can pop
|
441 | 449 | for operand_indices, contracted_names, einstr, _, _ in contraction_list:
|
@@ -465,13 +473,10 @@ def sum_repeats(
|
465 | 473 | lhs, rhs = map(einsum_operands.pop, operand_indices)
|
466 | 474 | lhs_names, rhs_names = input_names
|
467 | 475 |
|
468 |
| - # TODO: Do this as well? |
469 | 476 | # handle cases where one side of a contracting or batch dimension is 1
|
470 | 477 | # but its counterpart is not.
|
471 |
| - # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs), |
472 |
| - # rhs_names) |
473 |
| - # rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs), |
474 |
| - # lhs_names) |
| 478 | + lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, rhs, rhs_names) |
| 479 | + rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, lhs, lhs_names) |
475 | 480 |
|
476 | 481 | lhs_counts = collections.Counter(lhs_names)
|
477 | 482 | rhs_counts = collections.Counter(rhs_names)
|
|
0 commit comments