Commit c1bea19

Mypy

1 parent 84882a8 commit c1bea19

1 file changed: pytensor/tensor/optimize.py (28 additions, 33 deletions)
@@ -1,5 +1,6 @@
 from collections.abc import Sequence
 from copy import copy
+from typing import cast

 from scipy.optimize import minimize as scipy_minimize
 from scipy.optimize import root as scipy_root
@@ -10,16 +11,14 @@
 from pytensor.graph.basic import truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
-from pytensor.tensor.basic import atleast_2d, concatenate
+from pytensor.tensor.basic import atleast_2d, concatenate, zeros_like
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.variable import TensorVariable


 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

-    __props__ = ("method", "debug")
-
     def build_fn(self):
         """
         This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
@@ -93,28 +92,30 @@ class MinimizeOp(ScipyWrapperOp):

     def __init__(
         self,
-        x,
-        *args,
-        objective,
-        method="BFGS",
-        jac=True,
-        hess=False,
-        hessp=False,
-        options: dict | None = None,
+        x: Variable,
+        *args: Variable,
+        objective: Variable,
+        method: str = "BFGS",
+        jac: bool = True,
+        hess: bool = False,
+        hessp: bool = False,
+        optimizer_kwargs: dict | None = None,
         debug: bool = False,
     ):
         self.fgraph = FunctionGraph([x, *args], [objective])

         if jac:
-            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            grad_wrt_x = cast(
+                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            )
             self.fgraph.add_output(grad_wrt_x)

         self.jac = jac
         self.hess = hess
         self.hessp = hessp

         self.method = method
-        self.options = options if options is not None else {}
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
         self.debug = debug
         self._fn = None
         self._fn_wrapped = None
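The cast added above is a common mypy pattern: it changes nothing at runtime and only narrows the static type. A minimal self-contained sketch of the same situation, using hypothetical stand-ins rather than pytensor's real classes:

from collections.abc import Sequence
from typing import cast


class Variable:
    """Stand-in for pytensor's Variable."""


# Hypothetical annotation mirroring the issue: the declared return type
# is broader than what the caller actually gets back for a single `wrt`.
def grad(cost: Variable, wrt: Variable) -> Variable | Sequence[Variable]:
    return Variable()


def add_output(out: Variable) -> None:
    pass


cost, x = Variable(), Variable()

# Without the cast, mypy rejects add_output(...) because the declared
# return type includes Sequence[Variable]. cast() is a runtime no-op
# that tells the checker the value really is a Variable.
grad_wrt_x = cast(Variable, grad(cost, x))
add_output(grad_wrt_x)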
@@ -132,9 +133,6 @@ def perform(self, node, inputs, outputs):
             **self.options,
         )

-        if self.debug:
-            print(res)
-
         outputs[0][0] = res.x
         outputs[1][0] = res.success

@@ -185,12 +183,12 @@ def L_op(self, inputs, outputs, output_grads):


 def minimize(
-    objective,
-    x,
+    objective: TensorVariable,
+    x: TensorVariable,
     method: str = "BFGS",
     jac: bool = True,
     debug: bool = False,
-    options: dict | None = None,
+    optimizer_kwargs: dict | None = None,
 ):
     """
     Minimize a scalar objective function using scipy.optimize.minimize.
@@ -214,7 +212,7 @@ def minimize(
     debug : bool, optional
         If True, prints raw scipy result after optimization. Default is False.

-    **optimizer_kwargs
+    optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize

     Returns
@@ -236,7 +234,7 @@ def minimize(
         method=method,
         jac=jac,
         debug=debug,
-        options=options,
+        optimizer_kwargs=optimizer_kwargs,
     )

     return minimize_op(x, *args)
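A hedged usage sketch of the renamed keyword, based only on the signature shown in this diff (the return value of minimize is not shown in these hunks, so treat x_star as illustrative):

import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.dscalar("x")
a = pt.dscalar("a")  # extra input, picked up via truncated_graph_inputs

# (x - a)**2 is minimized at x = a.
objective = (x - a) ** 2

# optimizer_kwargs is forwarded to scipy.optimize.minimize, so anything
# scipy accepts there (tol, options={"maxiter": ...}, ...) can be passed.
x_star = minimize(
    objective,
    x,
    method="BFGS",
    optimizer_kwargs={"tol": 1e-8},
)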
@@ -247,12 +245,12 @@ class RootOp(ScipyWrapperOp):

     def __init__(
         self,
-        variables,
-        *args,
-        equations,
-        method="hybr",
-        jac=True,
-        options: dict | None = None,
+        variables: Variable,
+        *args: Variable,
+        equations: Variable,
+        method: str = "hybr",
+        jac: bool = True,
+        optimizer_kwargs: dict | None = None,
         debug: bool = False,
     ):
         self.fgraph = FunctionGraph([variables, *args], [equations])
@@ -264,7 +262,7 @@ def __init__(
         self.jac = jac

         self.method = method
-        self.options = options if options is not None else {}
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
         self.debug = debug
         self._fn = None
         self._fn_wrapped = None
@@ -279,12 +277,9 @@ def perform(self, node, inputs, outputs):
             x0=variables,
             args=tuple(args),
             method=self.method,
-            **self.options,
+            **self.optimizer_kwargs,
         )

-        if self.debug:
-            print(res)
-
         outputs[0][0] = res.x
         outputs[1][0] = res.success

@@ -309,7 +304,7 @@ def L_op(

         jac_wrt_args = solve(-jac_f_wrt_x_star, output_grad)

-        return [x.zeros_like(), jac_wrt_args]
+        return [zeros_like(x), jac_wrt_args]


 def root(