
Commit ddad8af

use static broadcasting
start working on static broadcasting
1 parent f4de2fd commit ddad8af
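
For context on what "static broadcasting" means for the types touched below: a tensor type carries both a static shape (entries are concrete ints, or None when unknown) and a broadcastable pattern (one bool per dimension), and after this commit only the latter decides whether a dimension may broadcast. A minimal stand-in illustrating the two attributes, using a hypothetical TypeStub rather than the real TensorType:

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass(frozen=True)
    class TypeStub:
        # mirrors the two static attributes the diff below keys on
        shape: Tuple[Optional[int], ...]   # e.g. (1, None)
        broadcastable: Tuple[bool, ...]    # e.g. (True, False)

    # Under static broadcasting, only `b` may broadcast its first dimension in
    # an Elemwise, even though both types have a static length of 1 there.
    a = TypeStub(shape=(1, None), broadcastable=(False, False))
    b = TypeStub(shape=(1, None), broadcastable=(True, False))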

File tree

5 files changed (+285, -126 lines)

pytensor/tensor/elemwise.py

Lines changed: 107 additions & 64 deletions
@@ -154,11 +154,11 @@ def __init__(self, input_broadcastable, new_order):
 
         # List of input dimensions to drop
         drop = []
-        for i, b in enumerate(input_broadcastable):
+        for i, bcasted in enumerate(input_broadcastable):
             if i not in new_order:
                 # We want to drop this dimension because it's not a value in
                 # `new_order`
-                if b == 1:
+                if bcasted:
                     drop.append(i)
                 else:
                     # We cannot drop non-broadcastable dimensions
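
As a quick standalone restatement of the drop rule above (plain Python with a hypothetical helper name, not the diff itself): a dimension left out of new_order can only be dropped when it is statically broadcastable.

    def dims_to_drop(input_broadcastable, new_order):
        # mirrors the loop in DimShuffle.__init__ shown above
        drop = []
        for i, bcasted in enumerate(input_broadcastable):
            if i not in new_order:
                if bcasted:
                    drop.append(i)
                else:
                    raise ValueError(f"cannot drop non-broadcastable dimension {i}")
        return drop

    print(dims_to_drop((True, False), new_order=(1,)))  # -> [0]
    # dims_to_drop((False, False), new_order=(1,)) raises ValueError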
@@ -187,13 +187,13 @@ def __setstate__(self, state):
 
     def make_node(self, _input):
         input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
-                raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                )
+        ib = input.type.broadcastable
+        if len(ib) != len(self.input_broadcastable):
+            raise TypeError(
+                "The number of dimensions of the "
+                f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+            )
+        else:
             for expected, b in zip(self.input_broadcastable, ib):
                 if expected is True and b is False:
                     raise TypeError(
@@ -270,7 +270,7 @@ def grad(self, inp, grads):
             return [inp[0].zeros_like(dtype=config.floatX)]
         else:
             return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
+                DimShuffle(gz.type.broadcastable, grad_order)(
                     Elemwise(scalar_identity)(gz)
                 )
             ]
@@ -407,7 +407,7 @@ def get_output_info(self, dim_shuffle, *inputs):
                 # TODO: use LComplete instead
                 args.append(
                     dim_shuffle(
-                        tuple(1 if s == 1 else None for s in input.type.shape),
+                        input.type.broadcastable,
                         ["x"] * difference + list(range(length)),
                     )(input)
                 )
@@ -419,45 +419,67 @@ def get_output_info(self, dim_shuffle, *inputs):
         # of all inputs in parallel... the all() gives us each output
         # broadcastable bit in turn.
 
-        def get_most_specialized_shape(shapes):
-            shapes = set(shapes)
-            # All shapes are the same
-            if len(shapes) == 1:
-                return tuple(shapes)[0]
-
-            # Only valid indeterminate case
-            if shapes == {None, 1}:
-                return None
-
-            shapes.discard(1)
-            shapes.discard(None)
-            if len(shapes) > 1:
-                raise ValueError
-            return tuple(shapes)[0]
+        def get_most_specialized_shape(shapes, bcasting):
+            seen_dims = set(zip(shapes, bcasting))
+            if len(seen_dims) == 1:
+                return next(iter(seen_dims))[0]
+            elif len(seen_dims) == 2 and (1, True) in seen_dims:
+                # this is fine since one dimension broadcasts to the other, common dimension
+                seen_dims.discard((1, True))
+                return next(iter(seen_dims))[0]
+            # the set has length >= 2 and is not just (1, True) plus one other dim,
+            # so drop the dims that do not matter in the comparison:
+            # we do not care about a 1 that broadcasts
+            # (the simple case was checked above, so the two discards
+            # never produce an empty set)
+            seen_dims.discard((1, True))
+            # we do not care about an unknown dim that does not broadcast because it
+            # should match a known dim that did not broadcast anyway
+            seen_dims.discard((None, False))
+            if len(seen_dims) > 1:
+                # we did not manage to specialize the dims, raise an error
+                raise ValueError(f"shapes and broadcast mismatch: {seen_dims}")
+            return next(iter(seen_dims))[0]
 
         # it is multiplied by nout because Elemwise supports multiple outputs
         # (nout of them)
         try:
             out_shapes = [
                 [
-                    get_most_specialized_shape(shape)
-                    for shape in zip(*[inp.type.shape for inp in inputs])
+                    get_most_specialized_shape(shape, bcasting)
+                    for shape, bcasting in zip(
+                        zip(*[inp.type.shape for inp in inputs]),
+                        zip(*[inp.type.broadcastable for inp in inputs]),
+                    )
+                ]
+            ] * shadow.nout
+            out_broadcastable = [
+                [
+                    all(bcast)
+                    for bcast in zip(*[inp.type.broadcastable for inp in inputs])
                 ]
             ] * shadow.nout
-        except ValueError:
+        except ValueError as e:
             raise ValueError(
-                f"Incompatible Elemwise input shapes {[inp.type.shape for inp in inputs]}"
-            )
+                "Incompatible Elemwise input broadcasting pattern: "
+                f"{[inp.type.broadcastable for inp in inputs]}. "
+                "PyTensor uses static broadcasting because it "
+                "simplifies gradients and the execution graph dramatically. "
+                "So even if input shapes contain ones, they should be "
+                "explicitly marked broadcastable; your current shapes are "
+                f"{[inp.type.shape for inp in inputs]}"
+            ) from e
 
         # inplace_pattern maps output idx -> input idx
         inplace_pattern = self.inplace_pattern
         if inplace_pattern:
             for overwriter, overwritten in inplace_pattern.items():
-                for out_s, in_s in zip(
-                    out_shapes[overwriter],
-                    inputs[overwritten].type.shape,
+                for out_b, in_b in zip(
+                    out_broadcastable[overwriter],
+                    inputs[overwritten].type.broadcastable,
                 ):
-                    if in_s == 1 and out_s != 1:
+                    # the dimension lost its broadcasting property
+                    if in_b and not out_b:
                         raise ValueError(
                             "Operation cannot be done inplace on an input "
                             "with broadcasted dimensions."
@@ -474,7 +496,7 @@ def get_most_specialized_shape(shapes):
                 )
             )
         assert len(out_dtypes) == len(out_shapes)
-        return out_dtypes, out_shapes, inputs
+        return out_dtypes, out_shapes, out_broadcastable, inputs
 
     def make_node(self, *inputs):
         """
@@ -483,10 +505,12 @@ def make_node(self, *inputs):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+        out_dtypes, out_shapes, out_broadcastable, inputs = self.get_output_info(
+            DimShuffle, *inputs
+        )
         outputs = [
-            TensorType(dtype=dtype, shape=shape)()
-            for dtype, shape in zip(out_dtypes, out_shapes)
+            TensorType(dtype=dtype, shape=shape, broadcastable=bcasting)()
+            for dtype, shape, bcasting in zip(out_dtypes, out_shapes, out_broadcastable)
         ]
         return Apply(self, inputs, outputs)
 
@@ -543,8 +567,6 @@ def connection_pattern(self, node):
         return [[True for output in node.outputs] for ipt in node.inputs]
 
     def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as at_sum
-
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)
 
@@ -577,16 +599,17 @@ def L_op(self, inputs, outs, ograds):
             # List of all the dimensions that are broadcastable for input[i] so
             # we can sum over them
             # TODO: only count dimensions that were effectively broadcasted
+            # the comment above was introduced in
+            # https://github.com/Theano/Theano/commit/1ddd6c38a7a627bab8ee1a4c3d45295fc9c6aace
             to_sum = [
                 j
-                for j, in_s in enumerate(ipt.type.shape)
-                if in_s == 1 and outs[0].type.shape[j] != 1
+                for j, in_broadcastable in enumerate(ipt.type.broadcastable)
+                if in_broadcastable and not outs[0].type.broadcastable[j]
             ]
-
             if to_sum:
-                sr = at_sum(rval[i], axis=to_sum, keepdims=True)
+                sr = pytensor.tensor.math.sum(rval[i], axis=to_sum, keepdims=True)
+                # TODO: check if the rval type is the same as ipt type
                 rval[i] = sr
-
         return rval
 
     def _bgrad(self, inputs, outputs, ograds):
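
A small NumPy sketch (illustrative, not PyTensor code) of why L_op sums the backpropagated gradient over these dimensions: an input dimension that is broadcastable while the output dimension is not receives the sum of the output gradient along that axis.

    import numpy as np

    # gradient flowing back from a (3, 5) output of e.g. x + y, where x has
    # static shape (1, 5) with a broadcastable first dim and y has shape (3, 5)
    g_out = np.ones((3, 5))

    in_bcast = (True, False)    # x's broadcastable pattern
    out_bcast = (False, False)  # the output's broadcastable pattern
    to_sum = [j for j, (i_b, o_b) in enumerate(zip(in_bcast, out_bcast)) if i_b and not o_b]

    g_x = g_out.sum(axis=tuple(to_sum), keepdims=True)
    print(g_x.shape)  # (1, 5): the gradient is collapsed back to x's static shape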
@@ -824,18 +847,38 @@ def perform(self, node, inputs, output_storage):
             storage[0] = variable
 
     def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
-
-        if len(node.outputs) > 1:
-            from pytensor.tensor.exceptions import ShapeError
-
-            raise ShapeError(
-                "Multiple outputs are not supported by the default `Elemwise.infer_shape`"
+        # some output shapes may already be static; unknown ones are None
+        # make_node already combined the input shape and broadcasting patterns,
+        # so use that as a hint here
+        out_shape = list(node.outputs[0].type.shape)
+        for i, (o_size, i_sizes, i_bcasting) in enumerate(
+            zip(
+                out_shape,
+                zip(*i_shapes),
+                zip(*(i.type.broadcastable for i in node.inputs)),
             )
-
-        out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
-
-        # The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
-        return [tuple(as_tensor_variable(s) for s in out_shape)]
+        ):
+            if o_size is None:
+                # unknown shape, infer it from one of the inputs
+                candidates = [
+                    s for s, bcasted in zip(i_sizes, i_bcasting) if not bcasted
+                ]
+                # a None shape cannot be broadcast, so its broadcastable flag is False;
+                # thus at least one of the inputs has to have a non-broadcastable dimension
+                # NOTE: a user can break that assumption in a custom op
+                # by manually creating output nodes
+                if len(candidates) == 0:
+                    from pytensor.tensor.rewriting.shape import ShapeError
+
+                    raise ShapeError(
+                        "Encountered a non-broadcastable unknown dimension in Elemwise, "
+                        "but all the input dims were broadcastable. "
+                        "This can happen if a custom make_node did not follow broadcasting conventions"
+                    )
+                # TODO: which best guess is better than 1 or i_sizes[0] or max(i_sizes)?
+                out_shape[i] = candidates[0]
+        # shape entries are scalars, so return a tuple of scalars
+        return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)
 
     def _c_all(self, node, nodename, inames, onames, sub):
         # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
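
The fallback in the new infer_shape can be summarized as: when the output's static size for a dimension is unknown, take the size reported by any input that is not broadcastable in that dimension. A standalone sketch of that choice (hypothetical helper, not part of the diff):

    def infer_dim(o_size, i_sizes, i_bcasting):
        if o_size is not None:
            return o_size  # the static hint recorded by make_node wins
        # any non-broadcastable input dimension must have the output's length
        candidates = [s for s, bcasted in zip(i_sizes, i_bcasting) if not bcasted]
        if not candidates:
            raise ValueError("unknown output dimension but all inputs broadcast in it")
        return candidates[0]

    print(infer_dim(None, i_sizes=(1, 7), i_bcasting=(True, False)))  # -> 7
    print(infer_dim(5, i_sizes=(5, 5), i_bcasting=(False, False)))    # -> 5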
@@ -898,7 +941,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
         # for each input:
         # same as range(ndim), but with 'x' at all broadcastable positions
         orders = [
-            [s == 1 and "x" or i for i, s in enumerate(input.type.shape)]
+            [bcast and "x" or i for i, bcast in enumerate(input.type.broadcastable)]
             for input in inputs
         ]
 
@@ -921,7 +964,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             [
                 f"PyArray_ISFORTRAN({arr})"
                 for arr, var in z
-                if not all(s == 1 for s in var.type.shape)
+                if not all(var.type.broadcastable)
             ]
         )
         # If it is a scalar, make it c contig to prevent problem with
@@ -1006,7 +1049,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             or
             # Use simpler code when output ndim == 0 or 1
             # or for broadcated scalar.
-            all(s == 1 for s in node.outputs[0].type.shape)
+            all(node.outputs[0].type.broadcastable)
         ):
             if nnested:
                 all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
@@ -1078,7 +1121,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             all(o.ndim >= 1 for o in node.outputs)
             and
             # Don't use the contig code for broadcasted scalar.
-            not all(s == 1 for s in node.outputs[0].type.shape)
+            not all(node.outputs[0].type.broadcastable)
         ):
             contig = None
             try:
@@ -1111,7 +1154,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             """
             index = ""
             for x, var in zip(inames + onames, inputs + node.outputs):
-                if not all(s == 1 for s in var.type.shape):
+                if not all(var.type.broadcastable):
                     contig += (
                         """
                     dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
@@ -1145,7 +1188,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
             )
         if contig is not None:
             z = list(zip(inames + onames, inputs + node.outputs))
-            all_broadcastable = all(s == 1 for s in var.type.shape)
+            all_broadcastable = all(var.type.broadcastable)
             cond1 = " && ".join(
                 [
                     "PyArray_ISCONTIGUOUS(%s)" % arr
@@ -1199,7 +1242,7 @@ def c_support_code_apply(self, node, nodename):
         return support_code
 
     def c_code_cache_version_apply(self, node):
-        version = [14]  # the version corresponding to the c code in this Op
+        version = [15]  # the version corresponding to the c code in this Op
 
         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
