Commit c233ce4

Implement minimize_scalar
1 parent 2f0119c commit c233ce4

File tree

2 files changed: 129 additions & 1 deletion


pytensor/tensor/optimize.py

Lines changed: 99 additions & 0 deletions
@@ -4,6 +4,7 @@

 import numpy as np
 from scipy.optimize import minimize as scipy_minimize
+from scipy.optimize import minimize_scalar as scipy_minimize_scalar
 from scipy.optimize import root as scipy_root

 from pytensor import Variable, function, graph_replace

@@ -90,6 +91,104 @@ def make_node(self, *inputs):
         )


+class MinimizeScalarOp(ScipyWrapperOp):
+    __props__ = ("method",)
+
+    def __init__(
+        self,
+        x: Variable,
+        *args: Variable,
+        objective: Variable,
+        method: str = "brent",
+        optimizer_kwargs: dict | None = None,
+    ):
+        self.fgraph = FunctionGraph([x, *args], [objective])
+
+        self.method = method
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
+        self._fn = None
+        self._fn_wrapped = None
+
+    def perform(self, node, inputs, outputs):
+        f = self.fn_wrapped
+        x0, *args = inputs
+
+        res = scipy_minimize_scalar(
+            fun=f,
+            args=tuple(args),
+            method=self.method,
+            **self.optimizer_kwargs,
+        )
+
+        outputs[0][0] = np.array(res.x)
+        outputs[1][0] = np.bool_(res.success)
+
+    def L_op(self, inputs, outputs, output_grads):
+        x, *args = inputs
+        x_star, _ = outputs
+        output_grad, _ = output_grads
+
+        inner_x, *inner_args = self.fgraph.inputs
+        inner_fx = self.fgraph.outputs[0]
+
+        implicit_f = grad(inner_fx, inner_x)
+        df_dx = grad(implicit_f, inner_x)
+
+        df_dthetas = [
+            grad(implicit_f, arg, disconnected_inputs="ignore") for arg in inner_args
+        ]
+
+        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+        df_dx_star, *df_dthetas_stars = graph_replace(
+            [df_dx, *df_dthetas], replace=replace
+        )
+
+        grad_wrt_args = [
+            (-df_dtheta_star / df_dx_star) * output_grad
+            for df_dtheta_star in df_dthetas_stars
+        ]
+
+        return [zeros_like(x), *grad_wrt_args]
+
+
+def minimize_scalar(
+    objective: TensorVariable,
+    x: TensorVariable,
+    method: str = "brent",
+    optimizer_kwargs: dict | None = None,
+):
+    """
+    Minimize a scalar objective function using scipy.optimize.minimize_scalar.
+    """
+
+    args = [
+        arg
+        for arg in graph_inputs([objective], [x])
+        if (arg is not x and not isinstance(arg, Constant))
+    ]
+
+    minimize_scalar_op = MinimizeScalarOp(
+        x,
+        *args,
+        objective=objective,
+        method=method,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    input_core_ndim = [var.ndim for var in minimize_scalar_op.inner_inputs]
+    input_signatures = [
+        f'({",".join(f"i{i}{n}" for n in range(ndim))})'
+        for i, ndim in enumerate(input_core_ndim)
+    ]
+
+    # Output dimensions are always the same as the first input (the initial values for the optimizer),
+    # then a scalar for the success flag
+    output_signatures = [input_signatures[0], "()"]
+
+    signature = f"{','.join(input_signatures)}->{','.join(output_signatures)}"
+    return Blockwise(minimize_scalar_op, signature=signature)(x, *args)
+
+
 class MinimizeOp(ScipyWrapperOp):
     __props__ = ("method", "jac", "hess", "hessp")

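Note on the gradient: MinimizeScalarOp.L_op applies the implicit function theorem to the first-order condition of the inner objective. Writing the objective as f(x, θ), with θ standing for any of the extra inputs in args, the minimizer x*(θ) satisfies

\left.\frac{\partial f}{\partial x}\right|_{x = x^*(\theta)} = 0
\qquad\Longrightarrow\qquad
\frac{d x^*}{d\theta} = -\,\left.\frac{\partial^2 f / \partial x\,\partial\theta}{\partial^2 f / \partial x^2}\right|_{x = x^*}

In the code, implicit_f is ∂f/∂x, df_dx and df_dthetas are its derivatives with respect to x and θ, graph_replace evaluates them at x_star, and grad_wrt_args is the ratio above scaled by output_grad. The gradient with respect to x itself is zeros_like(x), since the minimizer does not depend on the optimizer's starting point.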
tests/tensor/test_optimize.py

Lines changed: 30 additions & 1 deletion
@@ -3,13 +3,42 @@
 import pytensor
 import pytensor.tensor as pt
 from pytensor import config, function
-from pytensor.tensor.optimize import minimize, root
+from pytensor.tensor.optimize import minimize, minimize_scalar, root
 from tests import unittest_tools as utt


 floatX = config.floatX


+def test_minimize_scalar():
+    x = pt.scalar("x")
+    a = pt.scalar("a")
+    c = pt.scalar("c")
+
+    b = a * 2
+    b.name = "b"
+    out = (x - b * c) ** 2
+
+    minimized_x, success = minimize_scalar(out, x)
+
+    a_val = 2.0
+    c_val = 3.0
+
+    f = function([a, c, x], [minimized_x, success])
+
+    minimized_x_val, success_val = f(a_val, c_val, 0.0)
+
+    assert success_val
+    np.testing.assert_allclose(minimized_x_val, (2 * a_val * c_val))
+
+    def f(x, a, b):
+        objective = (x - a * b) ** 2
+        out = minimize_scalar(objective, x)[0]
+        return out
+
+    utt.verify_grad(f, [0.0, a_val, c_val], eps=1e-6)
+
+
 def test_simple_minimize():
     x = pt.scalar("x")
     a = pt.scalar("a")