@@ -154,11 +154,11 @@ def __init__(self, input_broadcastable, new_order):
 
         # List of input dimensions to drop
         drop = []
-        for i, b in enumerate(input_broadcastable):
+        for i, bcasted in enumerate(input_broadcastable):
             if i not in new_order:
                 # We want to drop this dimension because it's not a value in
                 # `new_order`
-                if b == 1:
+                if bcasted:
                     drop.append(i)
                 else:
                     # We cannot drop non-broadcastable dimensions
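For illustration, a minimal standalone sketch of the drop computation above; `compute_dropped_dims` is a hypothetical helper that mirrors the loop and is not part of the Op itself:

def compute_dropped_dims(input_broadcastable, new_order):
    # dims absent from `new_order` can only be dropped when broadcastable
    drop = []
    for i, bcasted in enumerate(input_broadcastable):
        if i not in new_order:
            if bcasted:
                drop.append(i)
            else:
                raise ValueError(f"Cannot drop non-broadcastable dimension {i}")
    return drop

# e.g. reordering (False, True, False) to (2, 0) drops the broadcastable dim 1
assert compute_dropped_dims((False, True, False), (2, 0)) == [1]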
@@ -187,13 +187,13 @@ def __setstate__(self, state):
 
     def make_node(self, _input):
         input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
-                raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                )
+        ib = input.type.broadcastable
+        if len(ib) != len(self.input_broadcastable):
+            raise TypeError(
+                "The number of dimensions of the "
+                f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+            )
+        else:
             for expected, b in zip(self.input_broadcastable, ib):
                 if expected is True and b is False:
                     raise TypeError(
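A rough sketch of the compatibility rule the rewritten check enforces, using plain tuples in place of the actual `TensorType` objects (the helper name is invented):

def check_input_broadcastable(expected, actual):
    # ndim must match exactly
    if len(actual) != len(expected):
        raise TypeError("The number of dimensions of the input is incorrect for this op.")
    # a dimension the op was built as broadcastable must stay broadcastable
    for exp, b in zip(expected, actual):
        if exp is True and b is False:
            raise TypeError("The broadcastable pattern of the input is incorrect for this op.")

check_input_broadcastable((True, False), (True, False))   # ok
check_input_broadcastable((False, False), (True, False))  # ok: the input may be more broadcastable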
@@ -270,7 +270,7 @@ def grad(self, inp, grads):
             return [inp[0].zeros_like(dtype=config.floatX)]
         else:
             return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
+                DimShuffle(gz.type.broadcastable, grad_order)(
                     Elemwise(scalar_identity)(gz)
                 )
             ]
@@ -407,7 +407,7 @@ def get_output_info(self, dim_shuffle, *inputs):
                 # TODO: use LComplete instead
                 args.append(
                     dim_shuffle(
-                        tuple(1 if s == 1 else None for s in input.type.shape),
+                        input.type.broadcastable,
                         ["x"] * difference + list(range(length)),
                     )(input)
                 )
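For reference, the `new_order` argument built above pads missing leading dimensions with "x"; a small sketch with invented values:

ndim_out = 4   # hypothetical target rank after upcasting
length = 2     # hypothetical rank of this particular input
difference = ndim_out - length
new_order = ["x"] * difference + list(range(length))
assert new_order == ["x", "x", 0, 1]  # prepend broadcastable dims, keep existing order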
@@ -419,45 +419,67 @@ def get_output_info(self, dim_shuffle, *inputs):
         # of all inputs in parallel... the all() gives us each output
         # broadcastable bit in turn.
 
-        def get_most_specialized_shape(shapes):
-            shapes = set(shapes)
-            # All shapes are the same
-            if len(shapes) == 1:
-                return tuple(shapes)[0]
-
-            # Only valid indeterminate case
-            if shapes == {None, 1}:
-                return None
-
-            shapes.discard(1)
-            shapes.discard(None)
-            if len(shapes) > 1:
-                raise ValueError
-            return tuple(shapes)[0]
+        def get_most_specialized_shape(shapes, bcasting):
+            seen_dims = set(zip(shapes, bcasting))
+            if len(seen_dims) == 1:
+                return next(iter(seen_dims))[0]
+            elif len(seen_dims) == 2 and (1, True) in seen_dims:
+                # this is fine since one dimension broadcasts to the other, common dimension
+                seen_dims.discard((1, True))
+                return next(iter(seen_dims))[0]
+            # here the set has length >= 2 and it is not just {(1, True), other},
+            # so drop the dims that do not matter in the comparison:
+            # we do not care about a 1 that broadcasts
+            # (the simple case was handled above, so the two discards below
+            # never produce an empty set)
+            seen_dims.discard((1, True))
+            # we do not care about an unknown dim that does not broadcast, because it
+            # should match a known dim that did not broadcast anyway
+            seen_dims.discard((None, False))
+            if len(seen_dims) > 1:
+                # we did not manage to specialize the dims, raise an error
+                raise ValueError(f"shapes and broadcast mismatch: {seen_dims}")
+            return next(iter(seen_dims))[0]
 
         # it is multiplied by nout because Elemwise supports multiple outputs
         # (nout of them)
         try:
             out_shapes = [
                 [
-                    get_most_specialized_shape(shape)
-                    for shape in zip(*[inp.type.shape for inp in inputs])
+                    get_most_specialized_shape(shape, bcasting)
+                    for shape, bcasting in zip(
+                        zip(*[inp.type.shape for inp in inputs]),
+                        zip(*[inp.type.broadcastable for inp in inputs]),
+                    )
+                ]
+            ] * shadow.nout
+            out_broadcastable = [
+                [
+                    all(bcast)
+                    for bcast in zip(*[inp.type.broadcastable for inp in inputs])
                 ]
             ] * shadow.nout
-        except ValueError:
+        except ValueError as e:
             raise ValueError(
-                f"Incompatible Elemwise input shapes {[inp.type.shape for inp in inputs]}"
-            )
+                "Incompatible Elemwise input broadcasting pattern: "
+                f"{[inp.type.broadcastable for inp in inputs]}. "
+                "PyTensor has static broadcasting behaviour, as it "
+                "simplifies gradients and the execution graph dramatically. "
+                "So even if input shapes contain ones, they should be "
+                "explicitly marked broadcastable; your current shapes are "
+                f"{[inp.type.shape for inp in inputs]}"
+            ) from e
 
         # inplace_pattern maps output idx -> input idx
         inplace_pattern = self.inplace_pattern
         if inplace_pattern:
             for overwriter, overwritten in inplace_pattern.items():
-                for out_s, in_s in zip(
-                    out_shapes[overwriter],
-                    inputs[overwritten].type.shape,
+                for out_b, in_b in zip(
+                    out_broadcastable[overwriter],
+                    inputs[overwritten].type.broadcastable,
                 ):
-                    if in_s == 1 and out_s != 1:
+                    # the dimension lost its broadcasting property
+                    if in_b and not out_b:
                         raise ValueError(
                             "Operation cannot be done inplace on an input "
                             "with broadcasted dimensions."
@@ -474,7 +496,7 @@ def get_most_specialized_shape(shapes):
                 )
             )
         assert len(out_dtypes) == len(out_shapes)
-        return out_dtypes, out_shapes, inputs
+        return out_dtypes, out_shapes, out_broadcastable, inputs
 
     def make_node(self, *inputs):
         """
@@ -483,10 +505,12 @@ def make_node(self, *inputs):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+        out_dtypes, out_shapes, out_broadcastable, inputs = self.get_output_info(
+            DimShuffle, *inputs
+        )
         outputs = [
-            TensorType(dtype=dtype, shape=shape)()
-            for dtype, shape in zip(out_dtypes, out_shapes)
+            TensorType(dtype=dtype, shape=shape, broadcastable=bcasting)()
+            for dtype, shape, bcasting in zip(out_dtypes, out_shapes, out_broadcastable)
         ]
         return Apply(self, inputs, outputs)
 
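As a small aside on the extra `out_broadcastable` value now threaded through `make_node`: an output dimension stays broadcastable only when every input is broadcastable there, which is what the `all(...)` in `get_output_info` computes. A minimal sketch with invented patterns:

in_broadcastables = [(True, False, False), (True, True, False)]
out_broadcastable = [all(bcast) for bcast in zip(*in_broadcastables)]
assert out_broadcastable == [True, False, False]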
@@ -543,8 +567,6 @@ def connection_pattern(self, node):
         return [[True for output in node.outputs] for ipt in node.inputs]
 
     def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as at_sum
-
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)
 
@@ -577,16 +599,17 @@ def L_op(self, inputs, outs, ograds):
             # List of all the dimensions that are broadcastable for input[i] so
             # we can sum over them
             # TODO: only count dimensions that were effectively broadcasted
+            # the comment above was introduced here:
+            # https://github.com/Theano/Theano/commit/1ddd6c38a7a627bab8ee1a4c3d45295fc9c6aace
             to_sum = [
                 j
-                for j, in_s in enumerate(ipt.type.shape)
-                if in_s == 1 and outs[0].type.shape[j] != 1
+                for j, in_broadcastable in enumerate(ipt.type.broadcastable)
+                if in_broadcastable and not outs[0].type.broadcastable[j]
             ]
-
             if to_sum:
-                sr = at_sum(rval[i], axis=to_sum, keepdims=True)
+                sr = pytensor.tensor.math.sum(rval[i], axis=to_sum, keepdims=True)
+                # TODO: check if the rval type is the same as ipt type
                 rval[i] = sr
-
         return rval
 
     def _bgrad(self, inputs, outputs, ograds):
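Why the summation is needed: when an input dimension is broadcastable but the corresponding output dimension is not, the forward pass implicitly repeats the input along that axis, so its gradient must be summed over that axis. A NumPy-only sketch, independent of the Op machinery:

import numpy as np

x = np.ones((1, 3))                      # input broadcastable along axis 0
g_out = np.arange(12.0).reshape(4, 3)    # gradient w.r.t. the (4, 3) broadcasted output

to_sum = [0]                             # axes broadcastable in the input but not in the output
g_x = g_out.sum(axis=tuple(to_sum), keepdims=True)
assert g_x.shape == x.shape              # (1, 3)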
@@ -824,18 +847,38 @@ def perform(self, node, inputs, output_storage):
             storage[0] = variable
 
     def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
-
-        if len(node.outputs) > 1:
-            from pytensor.tensor.exceptions import ShapeError
-
-            raise ShapeError(
-                "Multiple outputs are not supported by the default `Elemwise.infer_shape`"
+        # some shapes are already static, others are unknown (None)
+        # make_node copies the shape and broadcasting patterns to the outputs,
+        # so use that hint here
+        out_shape = list(node.outputs[0].type.shape)
+        for i, (o_size, i_sizes, i_bcasting) in enumerate(
+            zip(
+                out_shape,
+                zip(*i_shapes),
+                zip(*(i.type.broadcastable for i in node.inputs)),
             )
-
-        out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
-
-        # The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
-        return [tuple(as_tensor_variable(s) for s in out_shape)]
+        ):
+            if o_size is None:
+                # an unknown size, to be inferred from one of the inputs
+                candidates = [
+                    s for s, bcasted in zip(i_sizes, i_bcasting) if not bcasted
+                ]
+                # a None size can't be broadcasted, so it certainly has broadcastable=False;
+                # thus at least one of the inputs has to have a non-broadcastable dimension
+                # NOTE: it appears that a user can break this assumption in a custom op
+                # by manually creating output nodes
+                if len(candidates) == 0:
+                    from pytensor.tensor.rewriting.shape import ShapeError
+
+                    raise ShapeError(
+                        "Encountered a non-broadcastable unknown dimension in an Elemwise, "
+                        "but all the input dims were broadcastable. "
+                        "This can happen if a custom make_node did not follow broadcasting conventions"
+                    )
+                # TODO: which best guess is better, 1, i_sizes[0], or max(i_sizes)?
+                out_shape[i] = candidates[0]
+        # shape entries are scalars, so return a tuple of scalars
+        return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)
 
     def _c_all(self, node, nodename, inames, onames, sub):
         # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
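A standalone sketch of the candidate selection that the new `infer_shape` performs for a single unknown output dimension; plain Python values stand in for the symbolic shape entries and the helper name is invented:

def pick_unknown_dim(i_sizes, i_bcasting):
    # only non-broadcastable input dims can pin down an unknown,
    # non-broadcastable output dim
    candidates = [s for s, bcasted in zip(i_sizes, i_bcasting) if not bcasted]
    if not candidates:
        raise ValueError("all input dims were broadcastable")
    return candidates[0]

# the broadcastable 1 is ignored; the unknown symbolic size "n" is used
assert pick_unknown_dim(("n", 1), (False, True)) == "n"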
@@ -898,7 +941,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
         # for each input:
         # same as range(ndim), but with 'x' at all broadcastable positions
         orders = [
-            [s == 1 and "x" or i for i, s in enumerate(input.type.shape)]
+            [bcast and "x" or i for i, bcast in enumerate(input.type.broadcastable)]
             for input in inputs
         ]
 
@@ -921,7 +964,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             [
                 f"PyArray_ISFORTRAN({arr})"
                 for arr, var in z
-                if not all(s == 1 for s in var.type.shape)
+                if not all(var.type.broadcastable)
             ]
         )
         # If it is a scalar, make it c contig to prevent problem with
@@ -1006,7 +1049,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             or
             # Use simpler code when output ndim == 0 or 1
             # or for broadcated scalar.
-            all(s == 1 for s in node.outputs[0].type.shape)
+            all(node.outputs[0].type.broadcastable)
         ):
             if nnested:
                 all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
@@ -1078,7 +1121,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             all(o.ndim >= 1 for o in node.outputs)
             and
             # Don't use the contig code for broadcasted scalar.
-            not all(s == 1 for s in node.outputs[0].type.shape)
+            not all(node.outputs[0].type.broadcastable)
         ):
             contig = None
             try:
@@ -1111,7 +1154,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             """
             index = ""
             for x, var in zip(inames + onames, inputs + node.outputs):
-                if not all(s == 1 for s in var.type.shape):
+                if not all(var.type.broadcastable):
                     contig += (
                         """
             dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
@@ -1145,7 +1188,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             )
             if contig is not None:
                 z = list(zip(inames + onames, inputs + node.outputs))
-                all_broadcastable = all(s == 1 for s in var.type.shape)
+                all_broadcastable = all(var.type.broadcastable)
                 cond1 = " && ".join(
                     [
                         "PyArray_ISCONTIGUOUS(%s)" % arr
@@ -1199,7 +1242,7 @@ def c_support_code_apply(self, node, nodename):
         return support_code
 
     def c_code_cache_version_apply(self, node):
-        version = [14]  # the version corresponding to the c code in this Op
+        version = [15]  # the version corresponding to the c code in this Op
 
         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
0 commit comments