@@ -1434,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
1434
1434
return RavelMultiIndex (mode = mode , order = order )(* args )
1435
1435
1436
1436
1437
+ _broadcast_assert = Assert (
1438
+ "Could not broadcast dimensions. Broadcasting is only allowed along "
1439
+ "axes that have a statically known length 1. Use `specify_shape` to "
1440
+ "inform PyTensor of a known shape."
1441
+ )
1442
+
1443
+
1437
1444
def broadcast_shape (* arrays , ** kwargs ) -> Tuple [aes .ScalarVariable , ...]:
1438
1445
"""Compute the shape resulting from broadcasting arrays.
1439
1446
@@ -1510,20 +1517,19 @@ def broadcast_shape_iter(
1510
1517
for dim_shapes in zip (* array_shapes ):
1511
1518
# Get the shapes in this dimension that are not broadcastable
1512
1519
# (i.e. not symbolically known to be broadcastable)
1513
- maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at ]
1520
+ non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at ]
1514
1521
1515
- if len (maybe_non_bcast_shapes ) == 0 :
1522
+ if len (non_bcast_shapes ) == 0 :
1516
1523
# Every shape was broadcastable in this dimension
1517
1524
result_dims .append (one_at )
1518
- elif len (maybe_non_bcast_shapes ) == 1 :
1525
+ elif len (non_bcast_shapes ) == 1 :
1519
1526
# Only one shape might not be broadcastable in this dimension
1520
- result_dims .extend (maybe_non_bcast_shapes )
1527
+ result_dims .extend (non_bcast_shapes )
1521
1528
else :
1522
1529
# More than one shape might not be broadcastable in this dimension
1523
-
1524
1530
nonconst_nb_shapes : Set [int ] = set ()
1525
1531
const_nb_shapes : Set [Variable ] = set ()
1526
- for shape in maybe_non_bcast_shapes :
1532
+ for shape in non_bcast_shapes :
1527
1533
if isinstance (shape , Constant ):
1528
1534
const_nb_shapes .add (shape .value .item ())
1529
1535
else :
@@ -1534,7 +1540,6 @@ def broadcast_shape_iter(
1534
1540
f"Could not broadcast dimensions. Incompatible shapes were { array_shapes } ."
1535
1541
)
1536
1542
1537
- assert_op = Assert ("Could not dynamically broadcast dimensions." )
1538
1543
if len (const_nb_shapes ) == 1 :
1539
1544
(first_length ,) = const_nb_shapes
1540
1545
other_lengths = nonconst_nb_shapes
@@ -1547,23 +1552,8 @@ def broadcast_shape_iter(
1547
1552
continue
1548
1553
1549
1554
# Add assert that all remaining shapes are equal
1550
- use_scalars = False
1551
- if use_scalars :
1552
- condition = None
1553
- for other in other_lengths :
1554
- cond = aes .eq (first_length , other )
1555
- if condition is None :
1556
- condition = cond
1557
- else :
1558
- condition = aes .and_ (condition , cond )
1559
- else :
1560
- condition = pt_all (
1561
- [pt_eq (first_length , other ) for other in other_lengths ]
1562
- )
1563
- if condition is None :
1564
- result_dims .append (first_length )
1565
- else :
1566
- result_dims .append (assert_op (first_length , condition ))
1555
+ condition = pt_all ([pt_eq (first_length , other ) for other in other_lengths ])
1556
+ result_dims .append (_broadcast_assert (first_length , condition ))
1567
1557
1568
1558
return tuple (result_dims )
1569
1559
0 commit comments