Skip to content

Commit 53c26d1

Browse files
committed
Cosmetic improvements to dynamic broadcast checks
1 parent 0fbec99 commit 53c26d1

File tree

1 file changed

+14
-24
lines changed

1 file changed

+14
-24
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
14341434
return RavelMultiIndex(mode=mode, order=order)(*args)
14351435

14361436

1437+
_broadcast_assert = Assert(
1438+
"Could not broadcast dimensions. Broadcasting is only allowed along "
1439+
"axes that have a statically known length 1. Use `specify_shape` to "
1440+
"inform PyTensor of a known shape."
1441+
)
1442+
1443+
14371444
def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
14381445
"""Compute the shape resulting from broadcasting arrays.
14391446
@@ -1510,20 +1517,19 @@ def broadcast_shape_iter(
15101517
for dim_shapes in zip(*array_shapes):
15111518
# Get the shapes in this dimension that are not broadcastable
15121519
# (i.e. not symbolically known to be broadcastable)
1513-
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
1520+
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
15141521

1515-
if len(maybe_non_bcast_shapes) == 0:
1522+
if len(non_bcast_shapes) == 0:
15161523
# Every shape was broadcastable in this dimension
15171524
result_dims.append(one_at)
1518-
elif len(maybe_non_bcast_shapes) == 1:
1525+
elif len(non_bcast_shapes) == 1:
15191526
# Only one shape might not be broadcastable in this dimension
1520-
result_dims.extend(maybe_non_bcast_shapes)
1527+
result_dims.extend(non_bcast_shapes)
15211528
else:
15221529
# More than one shape might not be broadcastable in this dimension
1523-
15241530
nonconst_nb_shapes: Set[int] = set()
15251531
const_nb_shapes: Set[Variable] = set()
1526-
for shape in maybe_non_bcast_shapes:
1532+
for shape in non_bcast_shapes:
15271533
if isinstance(shape, Constant):
15281534
const_nb_shapes.add(shape.value.item())
15291535
else:
@@ -1534,7 +1540,6 @@ def broadcast_shape_iter(
15341540
f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
15351541
)
15361542

1537-
assert_op = Assert("Could not dynamically broadcast dimensions.")
15381543
if len(const_nb_shapes) == 1:
15391544
(first_length,) = const_nb_shapes
15401545
other_lengths = nonconst_nb_shapes
@@ -1547,23 +1552,8 @@ def broadcast_shape_iter(
15471552
continue
15481553

15491554
# Add assert that all remaining shapes are equal
1550-
use_scalars = False
1551-
if use_scalars:
1552-
condition = None
1553-
for other in other_lengths:
1554-
cond = aes.eq(first_length, other)
1555-
if condition is None:
1556-
condition = cond
1557-
else:
1558-
condition = aes.and_(condition, cond)
1559-
else:
1560-
condition = pt_all(
1561-
[pt_eq(first_length, other) for other in other_lengths]
1562-
)
1563-
if condition is None:
1564-
result_dims.append(first_length)
1565-
else:
1566-
result_dims.append(assert_op(first_length, condition))
1555+
condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
1556+
result_dims.append(_broadcast_assert(first_length, condition))
15671557

15681558
return tuple(result_dims)
15691559

0 commit comments

Comments
 (0)