Skip to content

Commit 0ea61bc

Browse files
Use grad to compute jacobian when input shape is known to be (1,) (#1454)
* More robust shape check for `grad` fallback in `jacobian` * Update scalar test
1 parent ff98ab8 commit 0ea61bc

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

pytensor/gradient.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,13 +2069,13 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
20692069
else:
20702070
wrt = [wrt]
20712071

2072-
if expression.ndim == 0:
2072+
if all(expression.type.broadcastable):
20732073
# expression is just a scalar, use grad
20742074
return as_list_or_tuple(
20752075
using_list,
20762076
using_tuple,
20772077
grad(
2078-
expression,
2078+
expression.squeeze(),
20792079
wrt,
20802080
consider_constant=consider_constant,
20812081
disconnected_inputs=disconnected_inputs,

tests/test_gradient.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.graph.basic import Apply, graph_inputs
3131
from pytensor.graph.null_type import NullType
3232
from pytensor.graph.op import Op
33+
from pytensor.scan.op import Scan
3334
from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh
3435
from pytensor.tensor.math import sum as pt_sum
3536
from pytensor.tensor.random import RandomStream
@@ -1036,6 +1037,17 @@ def test_jacobian_scalar():
10361037
vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
10371038
assert np.allclose(f(vx), 2)
10381039

1040+
# test when input is a shape (1,) vector -- should still be treated as a scalar
1041+
Jx = jacobian(y[None], x)
1042+
f = pytensor.function([x], Jx)
1043+
1044+
# Ensure we hit the scalar grad case (doesn't use scan)
1045+
nodes = f.maker.fgraph.apply_nodes
1046+
assert not any(isinstance(node.op, Scan) for node in nodes)
1047+
1048+
vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
1049+
assert np.allclose(f(vx), 2)
1050+
10391051
# test when the jacobian is called with a tuple as wrt
10401052
Jx = jacobian(y, (x,))
10411053
assert isinstance(Jx, tuple)

0 commit comments

Comments
 (0)