
Commit ee884b8

ricardoV94 authored and jessegrabowski committed
Fix Elemwise and Blockwise gradient for Ops with mixed discrete and continuous output types
1 parent 676296c commit ee884b8

4 files changed: 56 additions, 46 deletions
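For context, here is a minimal sketch of the behavior this commit fixes, closely mirroring the Blockwise test added below (the MixedOutputOp name is illustrative, not part of pytensor): an Op with one continuous and one discrete output. Per the guard removed in this commit, Elemwise.L_op and Blockwise.L_op used to replace every input gradient with zeros as soon as any output dtype was discrete; with this change, the gradient defined by the core Op's L_op flows through the continuous output.

import numpy as np

from pytensor import config, grad
from pytensor.graph import Op
from pytensor.tensor import ones_like, scalar, vector
from pytensor.tensor.blockwise import Blockwise


class MixedOutputOp(Op):
    # One float output and one integer output per core call
    gufunc_signature = "()->(),()"
    itypes = [scalar().type]
    otypes = [scalar().type, scalar(dtype=int).type]

    def perform(self, node, inputs, outputs):
        # Not needed to evaluate the gradient below (the added tests rely on the same property)
        raise NotImplementedError()

    def L_op(self, inputs, outputs, output_gradients):
        # The gradient is defined only through the first (continuous) output
        return [ones_like(inputs[0]) * output_gradients[0]]


x = vector("x")
y_float, y_int = Blockwise(MixedOutputOp())(x)

g = grad(y_float.sum(), x).eval({x: np.zeros(3, dtype=config.floatX)})
print(g)  # [1. 1. 1.] with this fix; the removed guard forced it to all zeros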

pytensor/tensor/blockwise.py

Lines changed: 5 additions & 22 deletions
@@ -18,7 +18,7 @@
 from pytensor.scalar import ScalarType
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
-from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
+from pytensor.tensor.type import TensorType, tensor
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
@@ -256,6 +256,10 @@ def as_core(t, core_t):
                 as_core(ograd, core_ograd)
                 for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
             ]
+            # FIXME: These core_outputs do not depend on core_inputs, not pretty
+            # It's not necessarily a problem because if they are referenced by the gradient,
+            # they get replaced later in vectorize. But if the Op was to make any decision
+            # by introspecting the dependencies of output on inputs it would fail badly!
             core_outputs = core_node.outputs

             core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
@@ -283,27 +287,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, inp in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = inp.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
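The FIXME added above deserves a concrete illustration: core_outputs are taken from a dummy core node built on separate placeholder inputs, so the outputs handed to core_op.L_op do not trace back to the core_inputs it also receives. Below is a self-contained sketch of that pitfall; the variable names and the depends_on helper are stand-ins for illustration, not the actual _bgrad internals.

from pytensor.tensor import scalar


def depends_on(var, target):
    # Tiny manual graph walk, for illustration only
    stack, seen = [var], set()
    while stack:
        v = stack.pop()
        if v is target:
            return True
        if v.owner is not None and id(v) not in seen:
            seen.add(id(v))
            stack.extend(v.owner.inputs)
    return False


core_input = scalar("core_input")    # plays the role of core_inputs[0]
dummy_input = scalar("dummy_input")  # the placeholder the dummy core node was built on
core_output = dummy_input + 1        # plays the role of core_outputs[0]

# The output does not depend on the input that L_op is given:
print(depends_on(core_output, core_input))  # False

As the comment notes, this is harmless as long as the gradient only references the outputs symbolically (they get replaced later during vectorization), but a core Op whose L_op made decisions by introspecting which inputs its outputs depend on would be misled.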

pytensor/tensor/elemwise.py

Lines changed: 0 additions & 21 deletions
@@ -515,27 +515,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, ipt in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = ipt.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # sum out the broadcasted dimensions
         for i, ipt in enumerate(inputs):
             if isinstance(rval[i].type, NullType | DisconnectedType):
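The guard deleted here, like its Blockwise counterpart above, amounted to the policy sketched below: whenever any output dtype was discrete, every input gradient that was not Null/Disconnected was replaced by zeros cast to floatX, which also threw away the gradient contributed through continuous outputs. A plain-NumPy restatement for illustration only; the is_discrete and old_guard names are mine, not pytensor API.

import numpy as np


def is_discrete(dtype: str) -> bool:
    # Illustrative stand-in for membership in pytensor's discrete_dtypes
    return dtype == "bool" or np.issubdtype(np.dtype(dtype), np.integer)


def old_guard(input_gradients, output_dtypes, floatX="float64"):
    """Restatement of the deleted branch: zero out every input gradient
    whenever any output of the Op has a discrete dtype."""
    if any(is_discrete(dt) for dt in output_dtypes):
        return [np.zeros_like(g, dtype=floatX) for g in input_gradients]
    return input_gradients


# A mixed float/int output pair turned the true gradient into zeros:
print(old_guard([np.ones(3)], output_dtypes=["float64", "int64"]))  # [array([0., 0., 0.])]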

tests/tensor/test_blockwise.py

Lines changed: 24 additions & 1 deletion
@@ -12,7 +12,7 @@
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import assert_op
-from pytensor.tensor import diagonal, log, tensor
+from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
@@ -603,3 +603,26 @@ def core_scipy_fn(A, b):
     # Confirm input was destroyed
     assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
     assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
+
+
+def test_gradient_mixed_discrete_output_core_op():
+    class MixedDtypeCoreOp(Op):
+        gufunc_signature = "()->(),()"
+        itypes = [scalar().type]
+        otypes = [scalar().type, scalar(dtype=int).type]
+
+        def perform(self, node, inputs, outputs):
+            raise NotImplementedError()
+
+        def L_op(self, inputs, outputs, output_gradients):
+            return [ones_like(inputs[0]) * output_gradients[0]]
+
+    op = Blockwise(MixedDtypeCoreOp())
+    x = vector("x")
+    y, _ = op(x)
+
+    np.testing.assert_array_equal(
+        grad(y.sum(), x).eval({x: np.full(12, np.nan, dtype=config.floatX)}),
+        np.ones(12, dtype=config.floatX),
+        strict=True,
+    )

tests/tensor/test_elemwise.py

Lines changed: 27 additions & 2 deletions
@@ -11,16 +11,16 @@
 import pytensor.scalar as ps
 import pytensor.tensor as pt
 import tests.unittest_tools as utt
-from pytensor import In, Out
+from pytensor import In, Out, config, grad
 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
-from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import vectorize_node
 from pytensor.link.basic import PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.npy_2_compat import numpy_maxdims
+from pytensor.scalar import ScalarOp, float32, float64, int32, int64
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import get_scalar_constant_value, second
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -1068,3 +1068,28 @@ def test_c_careduce_benchmark(axis, c_contiguous, benchmark):
     return careduce_benchmark_tester(
         axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark
     )
+
+
+def test_gradient_mixed_discrete_output_scalar_op():
+    class MixedDtypeScalarOp(ScalarOp):
+        def make_node(self, *inputs):
+            float_op = float64 if config.floatX == "float64" else float32
+            int_op = int64 if config.floatX == "int64" else int32
+            inputs = [float_op()]
+            outputs = [float_op(), int_op()]
+            return Apply(self, inputs, outputs)
+
+        def perform(self, node, inputs, outputs):
+            raise NotImplementedError()
+
+        def L_op(self, inputs, outputs, output_gradients):
+            return [inputs[0].ones_like() * output_gradients[0]]
+
+    op = Elemwise(MixedDtypeScalarOp())
+    x = vector("x")
+    y, _ = op(x)
+    np.testing.assert_array_equal(
+        grad(y.sum(), x).eval({x: np.full((12,), np.nan, dtype=config.floatX)}),
+        np.ones((12,), dtype=config.floatX),
+        strict=True,
+    )
