@@ -1,5 +1,6 @@
 from collections.abc import Sequence
 from copy import copy
+from typing import cast

 from scipy.optimize import minimize as scipy_minimize
 from scipy.optimize import root as scipy_root
@@ -10,16 +11,14 @@
 from pytensor.graph.basic import truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
-from pytensor.tensor.basic import atleast_2d, concatenate
+from pytensor.tensor.basic import atleast_2d, concatenate, zeros_like
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.variable import TensorVariable


 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

-    __props__ = ("method", "debug")
-
     def build_fn(self):
         """
         This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
@@ -93,28 +92,30 @@ class MinimizeOp(ScipyWrapperOp):

     def __init__(
         self,
-        x,
-        *args,
-        objective,
-        method="BFGS",
-        jac=True,
-        hess=False,
-        hessp=False,
-        options: dict | None = None,
+        x: Variable,
+        *args: Variable,
+        objective: Variable,
+        method: str = "BFGS",
+        jac: bool = True,
+        hess: bool = False,
+        hessp: bool = False,
+        optimizer_kwargs: dict | None = None,
         debug: bool = False,
     ):
         self.fgraph = FunctionGraph([x, *args], [objective])

         if jac:
-            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            grad_wrt_x = cast(
+                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            )
             self.fgraph.add_output(grad_wrt_x)

         self.jac = jac
         self.hess = hess
         self.hessp = hessp

         self.method = method
-        self.options = options if options is not None else {}
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
         self.debug = debug
         self._fn = None
         self._fn_wrapped = None
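A note on the `cast` added above (an aside, not part of the diff): `typing.cast` is a no-op at runtime; it only tells static checkers such as mypy to treat the result of `grad(...)` as a `Variable`, presumably because `grad` is annotated with a broader return type. A minimal standalone illustration:

```python
from typing import cast

x: object = 3
# cast() returns its second argument unchanged; only the static type narrows.
n = cast(int, x)
assert n is x and n + 1 == 4
```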
@@ -132,9 +133,6 @@ def perform(self, node, inputs, outputs):
-            **self.options,
+            **self.optimizer_kwargs,
         )

-        if self.debug:
-            print(res)
-
         outputs[0][0] = res.x
         outputs[1][0] = res.success
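For context on the `res.x` / `res.success` unpacking kept below the removed debug print: `scipy.optimize.minimize` returns an `OptimizeResult`, and these are two of its documented fields. A standalone sketch:

```python
from scipy.optimize import minimize

# Minimize (v - 2)^2 from a starting point of 0: res.x holds the argmin
# and res.success reports convergence, the same fields the Op stores.
res = minimize(lambda v: (v[0] - 2.0) ** 2, x0=[0.0], method="BFGS")
print(res.x, res.success)  # approximately [2.], True
```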
@@ -185,12 +183,12 @@ def L_op(self, inputs, outputs, output_grads):


 def minimize(
-    objective,
-    x,
+    objective: TensorVariable,
+    x: TensorVariable,
     method: str = "BFGS",
     jac: bool = True,
     debug: bool = False,
-    options: dict | None = None,
+    optimizer_kwargs: dict | None = None,
 ):
     """
     Minimize a scalar objective function using scipy.optimize.minimize.
@@ -214,7 +212,7 @@ def minimize(
     debug : bool, optional
         If True, prints raw scipy result after optimization. Default is False.

-    **optimizer_kwargs
+    optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize.

     Returns
@@ -236,7 +234,7 @@ def minimize(
         method=method,
         jac=jac,
         debug=debug,
-        options=options,
+        optimizer_kwargs=optimizer_kwargs,
     )

     return minimize_op(x, *args)
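The rename is plausibly motivated by the fact that `scipy.optimize.minimize` has its own keyword named `options`: since the dict is splatted into the scipy call as `**optimizer_kwargs`, scipy's `options` becomes just one entry inside it. A hypothetical usage sketch (the module path and the two-value return, inferred from the Op's `res.x` / `res.success` outputs, are assumptions):

```python
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize  # assumed module path

x = pt.scalar("x")
a = pt.scalar("a")
objective = (x - a) ** 2  # minimized over x; a is a free parameter

# Top-level scipy kwargs (tol, callback, ...) go in optimizer_kwargs;
# scipy's own `options` dict is one entry among them.
x_star, success = minimize(
    objective,
    x,
    method="BFGS",
    optimizer_kwargs={"tol": 1e-8, "options": {"maxiter": 100}},
)
```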
@@ -247,12 +245,12 @@ class RootOp(ScipyWrapperOp):

     def __init__(
         self,
-        variables,
-        *args,
-        equations,
-        method="hybr",
-        jac=True,
-        options: dict | None = None,
+        variables: Variable,
+        *args: Variable,
+        equations: Variable,
+        method: str = "hybr",
+        jac: bool = True,
+        optimizer_kwargs: dict | None = None,
         debug: bool = False,
     ):
         self.fgraph = FunctionGraph([variables, *args], [equations])
@@ -264,7 +262,7 @@ def __init__(
         self.jac = jac

         self.method = method
-        self.options = options if options is not None else {}
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
         self.debug = debug
         self._fn = None
         self._fn_wrapped = None
@@ -279,12 +277,9 @@ def perform(self, node, inputs, outputs):
             x0=variables,
             args=tuple(args),
             method=self.method,
-            **self.options,
+            **self.optimizer_kwargs,
         )

-        if self.debug:
-            print(res)
-
         outputs[0][0] = res.x
         outputs[1][0] = res.success
@@ -309,7 +304,7 @@ def L_op(
         jac_wrt_args = solve(-jac_f_wrt_x_star, output_grad)

-        return [x.zeros_like(), jac_wrt_args]
+        return [zeros_like(x), jac_wrt_args]

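The `L_op` above is the implicit function theorem in action: if `x*` solves `f(x*, θ) = 0`, differentiating through the solution gives the identity below, which `solve(-jac_f_wrt_x_star, output_grad)` applies to the incoming gradient (assuming `jac_f_wrt_x_star` holds ∂f/∂x evaluated at the root). The `zeros_like(x)` entry reflects that no gradient flows to the initial guess, only to the parameters.

```latex
% f(x^*(\theta), \theta) = 0 for all \theta, so by the chain rule:
\frac{\partial f}{\partial x}\bigg|_{x^*} \frac{d x^*}{d\theta}
  + \frac{\partial f}{\partial \theta} = 0
\quad\Longrightarrow\quad
\frac{d x^*}{d\theta}
  = -\left(\frac{\partial f}{\partial x}\bigg|_{x^*}\right)^{-1}
    \frac{\partial f}{\partial \theta}
```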
 def root(