11
11
from pytensor import Variable , function , graph_replace
12
12
from pytensor .gradient import grad , hessian , jacobian
13
13
from pytensor .graph import Apply , Constant , FunctionGraph
14
- from pytensor .graph .basic import graph_inputs , truncated_graph_inputs
14
+ from pytensor .graph .basic import truncated_graph_inputs
15
15
from pytensor .graph .op import ComputeMapType , HasInnerGraph , Op , StorageMapType
16
16
from pytensor .scalar import bool as scalar_bool
17
17
from pytensor .tensor import dot
@@ -106,6 +106,19 @@ def clear_cache(self):
106
106
self .hess_calls = 0
107
107
108
108
109
def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
    """
    Collect the inputs of ``objective`` that parameterize the optimization.

    Walks the graph of ``objective`` truncated at ``x`` and returns every
    non-constant input other than ``x`` itself. These are the extra
    arguments that must be passed alongside ``x`` when evaluating the
    objective function.
    """
    candidates = truncated_graph_inputs([objective], [x])
    return [
        var
        for var in candidates
        if var is not x and not isinstance(var, Constant)
    ]
120
+
121
+
109
122
class ScipyWrapperOp (Op , HasInnerGraph ):
110
123
"""Shared logic for scipy optimization ops"""
111
124
@@ -197,7 +210,9 @@ def perform(self, node, inputs, outputs):
197
210
f = self .fn_wrapped
198
211
f .clear_cache ()
199
212
200
- x0 , * args = inputs
213
+ # minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
214
+ # the args of the objective function), but it is not used in the optimization.
215
+ _ , * args = inputs
201
216
202
217
res = scipy_minimize_scalar (
203
218
fun = f .value ,
@@ -219,7 +234,7 @@ def L_op(self, inputs, outputs, output_grads):
219
234
220
235
implicit_f = grad (inner_fx , inner_x )
221
236
df_dx , * df_dthetas = grad (
222
- implicit_f , [inner_x , * inner_args ], disconnect_inputs = "ignore"
237
+ implicit_f , [inner_x , * inner_args ], disconnected_inputs = "ignore"
223
238
)
224
239
225
240
replace = dict (zip (self .fgraph .inputs , (x_star , * args ), strict = True ))
@@ -245,11 +260,7 @@ def minimize_scalar(
245
260
Minimize a scalar objective function using scipy.optimize.minimize_scalar.
246
261
"""
247
262
248
- args = [
249
- arg
250
- for arg in graph_inputs ([objective ], [x ])
251
- if (arg is not x and not isinstance (arg , Constant ))
252
- ]
263
+ args = _find_optimization_parameters (objective , x )
253
264
254
265
minimize_scalar_op = MinimizeScalarOp (
255
266
x ,
@@ -396,11 +407,7 @@ def minimize(
396
407
The optimized value of x that minimizes the objective function.
397
408
398
409
"""
399
- args = [
400
- arg
401
- for arg in graph_inputs ([objective ], [x ])
402
- if (arg is not x and not isinstance (arg , Constant ))
403
- ]
410
+ args = _find_optimization_parameters (objective , x )
404
411
405
412
minimize_op = MinimizeOp (
406
413
x ,
0 commit comments