Commit b7e0e1f

Implement root_scalar
1 parent 48c2c83 commit b7e0e1f

File tree

2 files changed: +136 lines, −3 lines

pytensor/tensor/optimize.py

Lines changed: 106 additions & 2 deletions
@@ -7,6 +7,7 @@
 from scipy.optimize import minimize as scipy_minimize
 from scipy.optimize import minimize_scalar as scipy_minimize_scalar
 from scipy.optimize import root as scipy_root
+from scipy.optimize import root_scalar as scipy_root_scalar

 from pytensor import Variable, function, graph_replace
 from pytensor.gradient import grad, hessian, jacobian
@@ -529,8 +530,111 @@ def minimize(
     return minimize_op(x, *args)


+class RootScalarOp(ScipyWrapperOp):
+    __props__ = ("method", "jac", "hess")
+
+    def __init__(
+        self,
+        variables,
+        *args,
+        equation,
+        method,
+        jac: bool = False,
+        hess: bool = False,
+        optimizer_kwargs=None,
+    ):
+        self.fgraph = FunctionGraph([variables, *args], [equation])
+
+        if jac:
+            f_prime = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            self.fgraph.add_output(f_prime)
+
+        if hess:
+            if not jac:
+                raise ValueError(
+                    "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
+                    " using first derivatives."
+                )
+            f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
+            self.fgraph.add_output(f_double_prime)
+
+        self.method = method
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
+        self.jac = jac
+        self.hess = hess
+
+        self._fn = None
+        self._fn_wrapped = None
+
+    def perform(self, node, inputs, outputs):
+        f = self.fn_wrapped
+        f.clear_cache()
+        # f.copy_x = True
+
+        variables, *args = inputs
+
+        res = scipy_root_scalar(
+            f=f.value,
+            fprime=f.grad if self.jac else None,
+            fprime2=f.hess if self.hess else None,
+            x0=variables,
+            args=tuple(args),
+            method=self.method,
+            **self.optimizer_kwargs,
+        )
+
+        outputs[0][0] = np.array(res.root)
+        outputs[1][0] = np.bool_(res.converged)
+
+    def L_op(self, inputs, outputs, output_grads):
+        x, *args = inputs
+        x_star, _ = outputs
+        output_grad, _ = output_grads
+
+        inner_x, *inner_args = self.fgraph.inputs
+        inner_fx = self.fgraph.outputs[0]
+
+        grad_wrt_args = scalar_implict_optimization_grads(
+            inner_fx=inner_fx,
+            inner_x=inner_x,
+            inner_args=inner_args,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
+        )
+
+        return [zeros_like(x), *grad_wrt_args]
+
+
+def root_scalar(
+    equation: TensorVariable,
+    variables: TensorVariable,
+    method: str = "secant",
+    jac: bool = False,
+    hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+):
+    """
+    Find roots of a scalar equation using scipy.optimize.root_scalar.
+    """
+    args = _find_optimization_parameters(equation, variables)
+
+    root_scalar_op = RootScalarOp(
+        variables,
+        *args,
+        equation=equation,
+        method=method,
+        jac=jac,
+        hess=hess,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    return root_scalar_op(variables, *args)
+
+
 class RootOp(ScipyWrapperOp):
-    __props__ = ("method", "jac", "optimizer_kwargs")
+    __props__ = ("method", "jac")

     def __init__(
         self,
@@ -616,4 +720,4 @@ def root(
     return root_op(variables, *args)


-__all__ = ["minimize", "root"]
+__all__ = ["minimize_scalar", "minimize", "root_scalar", "root"]

tests/tensor/test_optimize.py

Lines changed: 30 additions & 1 deletion
@@ -4,7 +4,7 @@
 import pytensor
 import pytensor.tensor as pt
 from pytensor import config, function
-from pytensor.tensor.optimize import minimize, minimize_scalar, root
+from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
 from tests import unittest_tools as utt
@@ -110,6 +110,35 @@ def f(x, a, b):
     utt.verify_grad(f, [x0, a_val, b_val], eps=1e-6)


+@pytest.mark.parametrize(
+    "method, jac, hess",
+    [("secant", False, False), ("newton", True, False), ("halley", True, True)],
+)
+def test_root_scalar(method, jac, hess):
+    x = pt.scalar("x")
+    a = pt.scalar("a")
+
+    def fn(x, a):
+        return x + 2 * a * pt.cos(x)
+
+    f = fn(x, a)
+    root_f, success = root_scalar(f, x, method=method, jac=jac, hess=hess)
+    func = pytensor.function([x, a], [root_f, success])
+
+    x0 = 0.0
+    a_val = 1.0
+    solution, success = func(x0, a_val)
+
+    assert success
+    np.testing.assert_allclose(solution, -1.02986653, atol=1e-6, rtol=1e-6)
+
+    def root_fn(x, a):
+        f = fn(x, a)
+        return root_scalar(f, x, method=method, jac=jac, hess=hess)[0]
+
+    utt.verify_grad(root_fn, [x0, a_val], eps=1e-6)
+
+
 def test_root_simple():
     x = pt.scalar("x")
     a = pt.scalar("a")
