@@ -128,6 +128,32 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
     ]


+def _get_parameter_grads_from_vector(
+    grad_wrt_args_vector: Variable,
+    x_star: Variable,
+    args: Sequence[Variable],
+    output_grad: Variable,
+):
+    """
+    Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
+    returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
+    """
+    cursor = 0
+    grad_wrt_args = []
+
+    for arg in args:
+        arg_shape = arg.shape
+        arg_size = arg_shape.prod()
+        arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+            (*x_star.shape, *arg_shape)
+        )
+
+        grad_wrt_args.append(dot(output_grad, arg_grad))
+        cursor += arg_size
+
+    return grad_wrt_args
+
+
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

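Note: the following is a minimal NumPy sketch, not part of the diff, of the bookkeeping that _get_parameter_grads_from_vector performs symbolically: slice the combined gradient matrix column-wise per parameter, restore each parameter's shape, and contract the leading x* axis with the upstream output gradient. The shapes, the random data, and the use of tensordot for the contraction are illustrative assumptions.

# Hypothetical illustration of the per-parameter unraveling done by the new helper.
import numpy as np

rng = np.random.default_rng(0)

x_star_shape = (4,)                    # shape of the optimization variable x*
param_shapes = [(2,), (3, 2)]          # shapes of two hypothetical parameters
total_size = sum(int(np.prod(s)) for s in param_shapes)

grad_wrt_args_vector = rng.normal(size=(*x_star_shape, total_size))  # (4, 8)
output_grad = rng.normal(size=x_star_shape)                          # (4,)

cursor = 0
grad_wrt_args = []
for shape in param_shapes:
    size = int(np.prod(shape))
    # columns belonging to this parameter, unraveled back to its original shape
    arg_grad = grad_wrt_args_vector[:, cursor : cursor + size].reshape(
        (*x_star_shape, *shape)
    )
    # contract the leading x* axis against the upstream gradient
    grad_wrt_args.append(np.tensordot(output_grad, arg_grad, axes=1))
    cursor += size

for shape, g in zip(param_shapes, grad_wrt_args):
    assert g.shape == shape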
@@ -348,34 +374,25 @@ def L_op(self, inputs, outputs, output_grads):

         implicit_f = grad(inner_fx, inner_x)

-        df_dx = atleast_2d(concatenate(jacobian(implicit_f, [inner_x]), axis=-1))
+        df_dx, *df_dtheta_columns = jacobian(
+            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+        )

         df_dtheta = concatenate(
-            [
-                atleast_2d(x, left=False)
-                for x in jacobian(implicit_f, inner_args, disconnected_inputs="ignore")
-            ],
+            [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
             axis=-1,
         )

         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))

-        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
+        df_dx_star, df_dtheta_star = graph_replace(
+            [atleast_2d(df_dx), df_dtheta], replace=replace
+        )

         grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
+        grad_wrt_args = _get_parameter_grads_from_vector(
+            grad_wrt_args_vector, x_star, args, output_grad
+        )

         return [zeros_like(x), *grad_wrt_args]

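Note: the solve(-df_dx_star, df_dtheta_star) step above is the implicit function theorem. At the optimum the inner gradient f(x*, theta) = d(objective)/dx vanishes, so dx*/dtheta = -(df/dx)^{-1} df/dtheta. A minimal NumPy check of that relation, not part of the diff, assuming a quadratic objective 0.5 * x @ A @ x - theta @ x whose minimizer has the closed form x* = inv(A) @ theta (A and theta are illustrative names):

# Hypothetical numerical check of the implicit-function-theorem gradient.
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(3, 3))
A = A @ A.T + 3 * np.eye(3)          # symmetric positive definite

df_dx = A                            # Jacobian of f = A @ x - theta w.r.t. x
df_dtheta = -np.eye(3)               # Jacobian of f w.r.t. theta

dxstar_dtheta = np.linalg.solve(-df_dx, df_dtheta)
# For this objective, dx*/dtheta should equal inv(A).
np.testing.assert_allclose(dxstar_dtheta, np.linalg.inv(A))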
@@ -504,19 +521,9 @@ def L_op(
         df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)

         grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
+        grad_wrt_args = _get_parameter_grads_from_vector(
+            grad_wrt_args_vector, x_star, args, output_grad
+        )

         return [zeros_like(x), *grad_wrt_args]

@@ -529,11 +536,7 @@ def root(
 ):
     """Find roots of a system of equations using scipy.optimize.root."""

-    args = [
-        arg
-        for arg in truncated_graph_inputs([equations], [variables])
-        if (arg is not variables and not isinstance(arg, Constant))
-    ]
+    args = _find_optimization_parameters(equations, variables)

    root_op = RootOp(variables, *args, equations=equations, method=method, jac=jac)
