Commit cdb41cc

mypy 😍
1 parent 85253c0 commit cdb41cc


pytensor/tensor/optimize.py

Lines changed: 23 additions & 8 deletions
@@ -139,10 +139,14 @@ def _get_parameter_grads_from_vector(
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
     """
+    grad_wrt_args_vector = cast(TensorVariable, grad_wrt_args_vector)
+    x_star = cast(TensorVariable, x_star)
+
     cursor = 0
     grad_wrt_args = []
 
     for arg in args:
+        arg = cast(TensorVariable, arg)
         arg_shape = arg.shape
         arg_size = arg_shape.prod()
         arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
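The casts in this hunk are the whole point of the commit: `typing.cast` does nothing at runtime, it only re-declares a value's static type so mypy will accept tensor-only attributes such as `.shape` on values annotated with the broader `Variable` base class. A minimal sketch of the pattern (variable names are illustrative, assuming a standard PyTensor install):

```python
from typing import cast

import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable

x = pt.dvector("x")

# cast() returns its argument unchanged; it exists purely for the type checker.
x_narrowed = cast(TensorVariable, x)
assert x_narrowed is x

# With the narrowed annotation, mypy accepts tensor-only attributes:
size = x_narrowed.shape.prod()
```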
@@ -233,16 +237,17 @@ def scalar_implict_optimization_grads(
     output_grad: Variable,
     fgraph: FunctionGraph,
 ) -> list[Variable]:
-    df_dx, *df_dthetas = grad(
-        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
+    df_dx, *df_dthetas = cast(
+        list[Variable],
+        grad(inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"),
     )
 
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
     df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)
 
     grad_wrt_args = [
         (-df_dtheta_star / df_dx_star) * output_grad
-        for df_dtheta_star in df_dthetas_stars
+        for df_dtheta_star in cast(list[TensorVariable], df_dthetas_stars)
     ]
 
     return grad_wrt_args
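For context, the comprehension above is the scalar implicit function theorem: if x*(θ) is defined implicitly by f(x*(θ), θ) = 0, differentiating through the optimum gives the ratio that the code multiplies by `output_grad`. In LaTeX (added here for clarity, not from the source):

```latex
\frac{dx^{*}}{d\theta}
  = -\left.\frac{\partial f / \partial \theta}{\partial f / \partial x}\right|_{x = x^{*}}
```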
@@ -297,15 +302,21 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
+    df_dx = cast(TensorVariable, df_dx)
+
     df_dtheta = concatenate(
-        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
+        [
+            atleast_2d(jac_col, left=False)
+            for jac_col in cast(list[TensorVariable], df_dtheta_columns)
+        ],
         axis=-1,
     )
 
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
 
-    df_dx_star, df_dtheta_star = graph_replace(
-        [atleast_2d(df_dx), df_dtheta], replace=replace
+    df_dx_star, df_dtheta_star = cast(
+        list[TensorVariable],
+        graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
     )
 
     grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
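The vector-valued version replaces the scalar division with a linear solve: `solve(-df_dx_star, df_dtheta_star)` computes -(∂f/∂x)⁻¹ (∂f/∂θ) without forming the inverse. A hedged numerical sketch of that identity (the array values are illustrative, not from optimize.py):

```python
import numpy as np

rng = np.random.default_rng(0)
df_dx_star = rng.normal(size=(3, 3)) + 3.0 * np.eye(3)  # well-conditioned Jacobian wrt x
df_dtheta_star = rng.normal(size=(3, 2))                 # Jacobian wrt the parameters

# What the hunk computes, via a solve rather than an explicit inverse:
grad_wrt_args_vector = np.linalg.solve(-df_dx_star, df_dtheta_star)

# Same quantity from the implicit function theorem written with an inverse:
expected = -np.linalg.inv(df_dx_star) @ df_dtheta_star
assert np.allclose(grad_wrt_args_vector, expected)
```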
@@ -546,7 +557,9 @@ def __init__(
         self.fgraph = FunctionGraph([variables, *args], [equation])
 
         if jac:
-            f_prime = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            f_prime = cast(
+                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            )
             self.fgraph.add_output(f_prime)
 
         if hess:
@@ -555,7 +568,9 @@ def __init__(
                 "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
                 " using first derivatives."
             )
-            f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
+            f_double_prime = cast(
+                Variable, grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
+            )
             self.fgraph.add_output(f_double_prime)
 
         self.method = method
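The last two hunks wrap `grad` in `cast(Variable, ...)`, presumably because `grad`'s return annotation is a union over single-variable and list results. The jac/hess chaining itself is plain repeated differentiation; a minimal sketch under the assumption of a scalar equation (names are illustrative, not from the commit):

```python
import pytensor.tensor as pt
from pytensor.gradient import grad

x = pt.dscalar("x")
equation = x**3 - 2 * x  # a scalar equation f(x)

f_prime = grad(equation, x)        # f'(x), appended as an fgraph output when jac=True
f_double_prime = grad(f_prime, x)  # f''(x), the grad of the last output when hess=True

# f'(x) = 3x^2 - 2 and f''(x) = 6x:
assert f_prime.eval({x: 2.0}) == 10.0
assert f_double_prime.eval({x: 2.0}) == 12.0
```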

0 commit comments
