Commit c587e33
Implement Root Op
1 parent 4669f0a commit c587e33

File tree: 2 files changed, +44 -10 lines

pytensor/tensor/optimize.py

Lines changed: 43 additions & 9 deletions
@@ -39,10 +39,11 @@ class LRUCache1:
     expensive functions.
     """

-    def __init__(self, fn):
+    def __init__(self, fn, copy_x: bool = False):
         self.fn = fn
         self.last_x = None
         self.last_result = None
+        self.copy_x = copy_x

         self.cache_hits = 0
         self.cache_misses = 0
@@ -59,9 +60,17 @@ def __call__(self, x, *args):
         If the input `x` is the same as the last input, return the cached result. Otherwise update the cache with the
         new input and result.
         """
+        # scipy.optimize.scalar_minimize and scalar_root don't take initial values as an argument, so we can't control
+        # the first input to the inner function. Of course, they use a scalar, but we need a 0d numpy array.
+        x = np.asarray(x)

         if self.last_result is None or not (x == self.last_x).all():
             self.cache_misses += 1
+
+            # scipy.optimize.root changes x in place, so the cache has to copy it, otherwise we get false
+            # cache hits and optimization always fails.
+            if self.copy_x:
+                x = x.copy()
             self.last_x = x

             result = self.fn(x, *args)
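
The cache change is easiest to see in isolation. Below is a minimal, hypothetical sketch of the same size-1 caching pattern (`TinyCache1` is illustrative, not part of the commit): the wrapper recomputes only when the incoming `x` differs from the stored one, and with `copy_x=True` it stores a private copy of the key, so a solver that edits its argument buffer in place cannot produce a false cache hit.

```python
import numpy as np


class TinyCache1:
    """Size-1 cache keyed on the argument array (illustrative sketch only)."""

    def __init__(self, fn, copy_x: bool = False):
        self.fn = fn
        self.copy_x = copy_x
        self.last_x = None
        self.last_result = None

    def __call__(self, x):
        x = np.asarray(x)  # accept plain scalars as 0-d arrays
        if self.last_result is None or not np.array_equal(x, self.last_x):
            if self.copy_x:
                # Keep a private copy so later in-place edits to the caller's
                # buffer cannot masquerade as a cache hit.
                x = x.copy()
            self.last_x = x
            self.last_result = self.fn(x)
        return self.last_result


f = TinyCache1(lambda v: v ** 2, copy_x=True)
buf = np.array([2.0])
print(f(buf))   # computed: [4.]
buf[0] = 3.0    # the caller mutates its buffer in place
print(f(buf))   # recomputed: [9.] -- the cache kept its own copy of the key
```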
@@ -449,6 +458,9 @@ def __init__(

     def perform(self, node, inputs, outputs):
         f = self.fn_wrapped
+        f.clear_cache()
+        f.copy_x = True
+
         variables, *args = inputs

         res = scipy_root(
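
For context on why `perform` clears the cache and enables `copy_x` before every solve: each call starts a fresh run of the SciPy root finder (`scipy_root` in the diff presumably aliases `scipy.optimize.root`), so nothing cached from a previous evaluation with different parameters should survive, and the solver's working copy of `x` must not alias the cache key. A hedged sketch of the same call pattern, with a made-up 2x2 NumPy system standing in for the compiled inner graph:

```python
import numpy as np
from scipy.optimize import root as scipy_root


def fn(x):
    # Stand-in for the wrapped inner function: returns f(x) and its Jacobian.
    f = np.array([x[0] ** 2 + x[1] - 3.0, x[0] - x[1]])
    jac = np.array([[2.0 * x[0], 1.0], [1.0, -1.0]])
    return f, jac


res = scipy_root(fn, x0=np.ones(2), jac=True)
# res.x comes back as a 1-d array and res.success as a bool; the next hunk
# reshapes the former to the input's shape and wraps the latter in np.bool_.
print(res.x, res.success)
```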
@@ -460,31 +472,53 @@ def perform(self, node, inputs, outputs):
             **self.optimizer_kwargs,
         )

-        outputs[0][0] = res.x
-        outputs[1][0] = res.success
+        outputs[0][0] = res.x.reshape(variables.shape)
+        outputs[1][0] = np.bool_(res.success)

     def L_op(
         self,
         inputs: Sequence[Variable],
         outputs: Sequence[Variable],
         output_grads: Sequence[Variable],
     ) -> list[Variable]:
-        # TODO: Broken
         x, *args = inputs
-        x_star, success = outputs
+        x_star, _ = outputs
         output_grad, _ = output_grads

         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]

-        inner_jac = jacobian(inner_fx, [inner_x, *inner_args])
+        df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
+
+        df_dtheta = concatenate(
+            [
+                atleast_2d(jac_column, left=False)
+                for jac_column in jacobian(
+                    inner_fx, inner_args, disconnected_inputs="ignore"
+                )
+            ],
+            axis=-1,
+        )

         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        jac_f_wrt_x_star, *jac_f_wrt_args = graph_replace(inner_jac, replace=replace)
+        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)

-        jac_wrt_args = solve(-jac_f_wrt_x_star, output_grad)
+        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)

-        return [zeros_like(x), jac_wrt_args]
+        cursor = 0
+        grad_wrt_args = []
+
+        for arg in args:
+            arg_shape = arg.shape
+            arg_size = arg_shape.prod()
+            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+                (*x_star.shape, *arg_shape)
+            )
+
+            grad_wrt_args.append(dot(output_grad, arg_grad))
+            cursor += arg_size
+
+        return [zeros_like(x), *grad_wrt_args]


 def root(
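
The rewritten `L_op` is the implicit-function-theorem gradient for a root-finding problem: at the solution, f(x*(theta), theta) = 0 holds identically in the parameters, which pins down how the root moves as the parameters change. A sketch of the derivation (the standard result, stated here for orientation rather than taken from the commit):

```latex
% f(x^*(\theta), \theta) = 0 for all \theta; differentiating w.r.t. \theta:
\frac{\partial f}{\partial x}\Big|_{x^*} \frac{\partial x^*}{\partial \theta}
  + \frac{\partial f}{\partial \theta}\Big|_{x^*} = 0
\quad\Longrightarrow\quad
\frac{\partial x^*}{\partial \theta}
  = -\left(\frac{\partial f}{\partial x}\Big|_{x^*}\right)^{-1}
     \frac{\partial f}{\partial \theta}\Big|_{x^*}
```

`solve(-df_dx_star, df_dtheta_star)` is exactly that linear solve, with the Jacobians for all parameters concatenated along the last axis. The loop afterwards slices each argument's block back out, reshapes it to `(*x_star.shape, *arg.shape)`, and contracts it with `output_grad`, i.e. the usual reverse-mode product of the upstream gradient with dx*/dtheta; the gradient with respect to the initial guess `x` is zero, since the root does not depend on where the solver starts.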

tests/tensor/test_optimize.py

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ def test_root_system_of_equations():

     f = pt.stack([a[0] * x[0] * pt.cos(x[1]) - b[0], x[0] * x[1] - a[1] * x[1] - b[1]])

-    root_f, success = root(f, x, debug=True)
+    root_f, success = root(f, x)
     func = pytensor.function([x, a, b], [root_f, success])

     x0 = np.array([1.0, 1.0])
