Skip to content

Commit 4669f0a

Browse files
use truncated_graph_inputs and refactor
1 parent 9807a18 commit 4669f0a

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

pytensor/tensor/optimize.py

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from pytensor import Variable, function, graph_replace
1212
from pytensor.gradient import grad, hessian, jacobian
1313
from pytensor.graph import Apply, Constant, FunctionGraph
14-
from pytensor.graph.basic import graph_inputs, truncated_graph_inputs
14+
from pytensor.graph.basic import truncated_graph_inputs
1515
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
1616
from pytensor.scalar import bool as scalar_bool
1717
from pytensor.tensor import dot
@@ -106,6 +106,19 @@ def clear_cache(self):
106106
self.hess_calls = 0
107107

108108

109+
def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
110+
"""
111+
Find the parameters of the optimization problem that are not the variable `x`.
112+
113+
This is used to determine the additional arguments that need to be passed to the objective function.
114+
"""
115+
return [
116+
arg
117+
for arg in truncated_graph_inputs([objective], [x])
118+
if (arg is not x and not isinstance(arg, Constant))
119+
]
120+
121+
109122
class ScipyWrapperOp(Op, HasInnerGraph):
110123
"""Shared logic for scipy optimization ops"""
111124

@@ -197,7 +210,9 @@ def perform(self, node, inputs, outputs):
197210
f = self.fn_wrapped
198211
f.clear_cache()
199212

200-
x0, *args = inputs
213+
# minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
214+
# the args of the objective function), but it is not used in the optimization.
215+
_, *args = inputs
201216

202217
res = scipy_minimize_scalar(
203218
fun=f.value,
@@ -219,7 +234,7 @@ def L_op(self, inputs, outputs, output_grads):
219234

220235
implicit_f = grad(inner_fx, inner_x)
221236
df_dx, *df_dthetas = grad(
222-
implicit_f, [inner_x, *inner_args], disconnect_inputs="ignore"
237+
implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
223238
)
224239

225240
replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
@@ -245,11 +260,7 @@ def minimize_scalar(
245260
Minimize a scalar objective function using scipy.optimize.minimize_scalar.
246261
"""
247262

248-
args = [
249-
arg
250-
for arg in graph_inputs([objective], [x])
251-
if (arg is not x and not isinstance(arg, Constant))
252-
]
263+
args = _find_optimization_parameters(objective, x)
253264

254265
minimize_scalar_op = MinimizeScalarOp(
255266
x,
@@ -396,11 +407,7 @@ def minimize(
396407
The optimized value of x that minimizes the objective function.
397408
398409
"""
399-
args = [
400-
arg
401-
for arg in graph_inputs([objective], [x])
402-
if (arg is not x and not isinstance(arg, Constant))
403-
]
410+
args = _find_optimization_parameters(objective, x)
404411

405412
minimize_op = MinimizeOp(
406413
x,

0 commit comments

Comments (0)