@@ -139,10 +139,14 @@ def _get_parameter_grads_from_vector(
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
     """
+    grad_wrt_args_vector = cast(TensorVariable, grad_wrt_args_vector)
+    x_star = cast(TensorVariable, x_star)
+
     cursor = 0
     grad_wrt_args = []
 
     for arg in args:
+        arg = cast(TensorVariable, arg)
         arg_shape = arg.shape
         arg_size = arg_shape.prod()
         arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
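For context, a minimal NumPy sketch of the unravelling described in the docstring above: a flat gradient vector is split per parameter and each block is reshaped back to that parameter's original shape. The parameter names and sizes here are made up for illustration and are not part of the diff.

import numpy as np

# Hypothetical parameters and a flat gradient covering all raveled entries.
params = {"a": np.zeros((2, 3)), "b": np.zeros(4)}
flat_grad = np.arange(10.0)  # 2*3 + 4 = 10 raveled entries

cursor = 0
grads = {}
for name, value in params.items():
    size = value.size
    # Slice this parameter's block out of the flat vector and restore its shape.
    grads[name] = flat_grad[cursor : cursor + size].reshape(value.shape)
    cursor += size

print(grads["a"].shape, grads["b"].shape)  # (2, 3) (4,)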
@@ -233,16 +237,17 @@ def scalar_implict_optimization_grads(
     output_grad: Variable,
     fgraph: FunctionGraph,
 ) -> list[Variable]:
-    df_dx, *df_dthetas = grad(
-        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
+    df_dx, *df_dthetas = cast(
+        list[Variable],
+        grad(inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"),
     )
 
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
     df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)
 
     grad_wrt_args = [
         (-df_dtheta_star / df_dx_star) * output_grad
-        for df_dtheta_star in df_dthetas_stars
+        for df_dtheta_star in cast(list[TensorVariable], df_dthetas_stars)
     ]
 
     return grad_wrt_args
@@ -297,15 +302,21 @@ def implict_optimization_grads(
     fgraph: FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
+    df_dx = cast(TensorVariable, df_dx)
+
     df_dtheta = concatenate(
-        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
+        [
+            atleast_2d(jac_col, left=False)
+            for jac_col in cast(list[TensorVariable], df_dtheta_columns)
+        ],
         axis=-1,
     )
 
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
 
-    df_dx_star, df_dtheta_star = graph_replace(
-        [atleast_2d(df_dx), df_dtheta], replace=replace
+    df_dx_star, df_dtheta_star = cast(
+        list[TensorVariable],
+        graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
     )
 
     grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
@@ -546,7 +557,9 @@ def __init__(
         self.fgraph = FunctionGraph([variables, *args], [equation])
 
         if jac:
-            f_prime = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            f_prime = cast(
+                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            )
             self.fgraph.add_output(f_prime)
 
         if hess:
@@ -555,7 +568,9 @@ def __init__(
                     "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
                     " using first derivatives."
                 )
-            f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
+            f_double_prime = cast(
+                Variable, grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
+            )
             self.fgraph.add_output(f_double_prime)
 
         self.method = method
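Aside from the diff itself: every change in this commit wraps a graph-building result in typing.cast(...), which is a runtime no-op that only narrows the type seen by static checkers such as mypy. A minimal sketch of the pattern, using an illustrative PyTensor graph rather than the functions above:

from typing import cast

import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable

x = pt.dscalar("x")
y = x**2

# grad() is annotated to return Variable(s); cast() tells the type checker that
# this particular result is a TensorVariable without changing runtime behavior.
dy_dx = cast(TensorVariable, pt.grad(y, x))
print(dy_dx.type)  # same graph object as before, only re-typed for the checker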