@@ -39,10 +39,11 @@ class LRUCache1:
     expensive functions.
     """

-    def __init__(self, fn):
+    def __init__(self, fn, copy_x: bool = False):
         self.fn = fn
         self.last_x = None
         self.last_result = None
+        self.copy_x = copy_x

         self.cache_hits = 0
         self.cache_misses = 0
@@ -59,9 +60,17 @@ def __call__(self, x, *args):
         If the input `x` is the same as the last input, return the cached result. Otherwise update the cache with the
         new input and result.
         """
+        # scipy.optimize.minimize_scalar and root_scalar don't take an initial value as an argument, so we can't
+        # control the first input to the inner function. They pass a plain scalar, but we need a 0d numpy array.
+        x = np.asarray(x)

         if self.last_result is None or not (x == self.last_x).all():
             self.cache_misses += 1
+
+            # scipy.optimize.root changes x in place, so the cache has to copy it; otherwise we get false
+            # cache hits and the optimization always fails.
+            if self.copy_x:
+                x = x.copy()
             self.last_x = x

             result = self.fn(x, *args)
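Why the copy matters: scipy.optimize.root reuses and overwrites its work array, so if the cache stores a reference to that buffer, the stored key changes together with each new proposal and every comparison becomes a false hit. The standalone sketch below reproduces the failure mode and the fix outside of the Op machinery; `OneStepCache` and `caller_that_mutates` are illustrative names, not part of this PR, and the in-place caller only simulates the optimizer behaviour described in the comment above.

import numpy as np


class OneStepCache:
    """Hypothetical single-entry cache in the spirit of LRUCache1 above."""

    def __init__(self, fn, copy_x: bool = False):
        self.fn = fn
        self.copy_x = copy_x
        self.last_x = None
        self.last_result = None
        self.cache_hits = 0
        self.cache_misses = 0

    def __call__(self, x):
        x = np.asarray(x)
        if self.last_result is None or not (x == self.last_x).all():
            self.cache_misses += 1
            # Keep our own copy so a caller that mutates `x` in place
            # cannot silently change the cached key as well.
            if self.copy_x:
                x = x.copy()
            self.last_x = x
            self.last_result = self.fn(x)
        else:
            self.cache_hits += 1
        return self.last_result


def caller_that_mutates(f):
    """Stand-in for an optimizer that reuses and overwrites one work array."""
    x = np.array([1.0, 2.0])
    f(x)         # first call, cache miss
    x += 1.0     # in-place update of the same buffer
    return f(x)  # without copy_x this is a false cache hit


f_no_copy = OneStepCache(lambda v: v.sum(), copy_x=False)
f_copy = OneStepCache(lambda v: v.sum(), copy_x=True)

print(caller_that_mutates(f_no_copy), f_no_copy.cache_hits)  # stale 3.0, one false hit
print(caller_that_mutates(f_copy), f_copy.cache_hits)        # correct 5.0, no hits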
@@ -449,6 +458,9 @@ def __init__(

    def perform(self, node, inputs, outputs):
        f = self.fn_wrapped
+        f.clear_cache()
+        f.copy_x = True
+
        variables, *args = inputs

        res = scipy_root(
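`perform` resets the wrapper before every call so results memoized during a previous evaluation (possibly with different parameter values) cannot leak into this one, and switches on `copy_x` because `scipy_root` is the in-place caller discussed above. The body of `clear_cache` is not shown in this hunk; a plausible sketch, reusing the `OneStepCache` stand-in from the earlier example and assuming the method only has to drop the memoized point and reset the counters, would be:

class OneStepCacheWithReset(OneStepCache):
    def clear_cache(self):
        # Forget the memoized point and reset the hit/miss counters so a new
        # perform() call starts from a clean slate.
        self.last_x = None
        self.last_result = None
        self.cache_hits = 0
        self.cache_misses = 0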
@@ -460,31 +472,53 @@ def perform(self, node, inputs, outputs):
            **self.optimizer_kwargs,
        )

-        outputs[0][0] = res.x
-        outputs[1][0] = res.success
+        outputs[0][0] = res.x.reshape(variables.shape)
+        outputs[1][0] = np.bool_(res.success)

    def L_op(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
        output_grads: Sequence[Variable],
    ) -> list[Variable]:
-        # TODO: Broken
        x, *args = inputs
-        x_star, success = outputs
+        x_star, _ = outputs
        output_grad, _ = output_grads

        inner_x, *inner_args = self.fgraph.inputs
        inner_fx = self.fgraph.outputs[0]

-        inner_jac = jacobian(inner_fx, [inner_x, *inner_args])
+        df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
+
+        df_dtheta = concatenate(
+            [
+                atleast_2d(jac_column, left=False)
+                for jac_column in jacobian(
+                    inner_fx, inner_args, disconnected_inputs="ignore"
+                )
+            ],
+            axis=-1,
+        )

        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        jac_f_wrt_x_star, *jac_f_wrt_args = graph_replace(inner_jac, replace=replace)
+        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)

-        jac_wrt_args = solve(-jac_f_wrt_x_star, output_grad)
+        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)

-        return [zeros_like(x), jac_wrt_args]
+        cursor = 0
+        grad_wrt_args = []
+
+        for arg in args:
+            arg_shape = arg.shape
+            arg_size = arg_shape.prod()
+            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+                (*x_star.shape, *arg_shape)
+            )
+
+            grad_wrt_args.append(dot(output_grad, arg_grad))
+            cursor += arg_size
+
+        return [zeros_like(x), *grad_wrt_args]


def root(
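The rewritten `L_op` is an application of the implicit function theorem: differentiating f(x*(theta), theta) = 0 at the solution gives dx*/dtheta = -(df/dx)^{-1} df/dtheta, which is exactly what `solve(-df_dx_star, df_dtheta_star)` builds before the upstream gradient is contracted in via `dot(output_grad, arg_grad)`. Below is a minimal numeric sanity check of that formula against central finite differences, using plain numpy/scipy only; the toy system `f` and every name in it are illustrative and not taken from the PR.

import numpy as np
from scipy.optimize import root


def f(x, theta):
    # Small system with a unique, well-conditioned root for 0 < theta < 1.
    return np.array([
        x[0] + theta[0] * np.tanh(x[1]) - 1.0,
        x[1] + theta[1] * np.tanh(x[0]) - 2.0,
    ])


def solve_root(theta):
    return root(f, x0=np.zeros(2), args=(theta,), tol=1e-12).x


theta = np.array([0.5, 0.3])
x_star = solve_root(theta)

# Jacobians of f at the solution, written out by hand for this toy system.
sech2 = lambda t: 1.0 / np.cosh(t) ** 2
J_x = np.array([[1.0, theta[0] * sech2(x_star[1])],
                [theta[1] * sech2(x_star[0]), 1.0]])
J_theta = np.array([[np.tanh(x_star[1]), 0.0],
                    [0.0, np.tanh(x_star[0])]])

# Implicit function theorem: dx*/dtheta = solve(-J_x, J_theta).
dx_dtheta = np.linalg.solve(-J_x, J_theta)

# Central finite-difference check of the same sensitivities.
eps = 1e-6
fd = np.column_stack([
    (solve_root(theta + eps * e) - solve_root(theta - eps * e)) / (2 * eps)
    for e in np.eye(2)
])
assert np.allclose(dx_dtheta, fd, atol=1e-5)

Each column of `dx_dtheta` is the sensitivity of `x_star` to one entry of `theta`, the same layout `grad_wrt_args_vector` has before it is sliced per argument and contracted with `output_grad`.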