@@ -47,6 +47,8 @@ def __init__(self, fn):
        self.cache_hits = 0
        self.cache_misses = 0

+        self.value_calls = 0
+        self.grad_calls = 0
        self.value_and_grad_calls = 0
        self.hess_calls = 0

@@ -57,26 +59,27 @@ def __call__(self, x, *args):
        If the input `x` is the same as the last input, return the cached result. Otherwise update the cache with the
        new input and result.
        """
-        cache_hit = np.all(x == self.last_x)

-        if self.last_x is None or not cache_hit:
+        if self.last_result is None or not (x == self.last_x).all():
            self.cache_misses += 1
-            result = self.fn(x, *args)
            self.last_x = x
+
+            result = self.fn(x, *args)
            self.last_result = result
+
            return result

        else:
            self.cache_hits += 1
            return self.last_result

    def value(self, x, *args):
-        self.value_and_grad_calls += 1
-        res = self(x, *args)
-        if isinstance(res, tuple):
-            return res[0]
-        else:
-            return res
+        self.value_calls += 1
+        return self(x, *args)[0]
+
+    def grad(self, x, *args):
+        self.grad_calls += 1
+        return self(x, *args)[1]

    def value_and_grad(self, x, *args):
        self.value_and_grad_calls += 1
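
The caching change above is the core of this hunk: the wrapper now keys the cache on `last_result` plus an exact elementwise comparison with `last_x`, and exposes separate `value` and `grad` accessors so scipy can request the objective and its gradient independently while the wrapped graph is evaluated only once per distinct `x`. A minimal standalone sketch of the same single-entry caching pattern (hypothetical names, not the library code), assuming the wrapped callable returns a `(value, grad)` tuple:

import numpy as np

class OneSlotCache:
    def __init__(self, fn):
        self.fn, self.last_x, self.last_result = fn, None, None

    def __call__(self, x):
        # Recompute only when x differs from the last input seen.
        if self.last_result is None or not (x == self.last_x).all():
            self.last_x = x
            self.last_result = self.fn(x)
        return self.last_result

    def value(self, x):
        return self(x)[0]

    def grad(self, x):
        return self(x)[1]

cache = OneSlotCache(lambda x: ((x**2).sum(), 2 * x))
x = np.array([1.0, 2.0])
print(cache.value(x), cache.grad(x))  # second call hits the cache; fn ran once
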
@@ -97,6 +100,8 @@ def clear_cache(self):
        self.last_result = None
        self.cache_hits = 0
        self.cache_misses = 0
+        self.value_calls = 0
+        self.grad_calls = 0
        self.value_and_grad_calls = 0
        self.hess_calls = 0

@@ -109,14 +114,8 @@ def build_fn(self):
        This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
        wrapper function logic is there to handle this.
        """
-        # TODO: Introduce rewrites to change MinimizeOp to MinimizeScalarOp and RootOp to RootScalarOp
-        # when x is scalar. That will remove the need for the wrapper.
-
        outputs = self.inner_outputs
-        if len(outputs) == 1:
-            outputs = outputs[0]
-        self._fn = fn = function(self.inner_inputs, outputs)
-
+        self._fn = fn = function(self.inner_inputs, outputs, trust_input=True)

        # Do this reassignment to see the compiled graph in the dprint
        # self.fgraph = fn.maker.fgraph
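
If I read this hunk correctly, passing `trust_input=True` makes the compiled inner function skip per-call input validation and conversion, which matters because scipy calls it on every optimizer iteration. A small sketch of the flag on an ordinary compiled function, assuming a PyTensor version whose `function` accepts `trust_input` as a keyword (as this diff does); the toy graph is illustrative only:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
# trust_input=True skips input checking/filtering on each call, so the caller
# is responsible for passing values that already match the declared type.
f = pytensor.function([x], (x**2).sum(), trust_input=True)
print(f(np.array([1.0, 2.0, 3.0])))  # 14.0
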
@@ -166,6 +165,10 @@ def prepare_node(

    def make_node(self, *inputs):
        assert len(inputs) == len(self.inner_inputs)
+        for input, inner_input in zip(inputs, self.inner_inputs):
+            assert (
+                input.type == inner_input.type
+            ), f"Input {input} does not match expected type {inner_input.type}"

        return Apply(
            self, inputs, [self.inner_inputs[0].type(), scalar_bool("success")]
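
The new assertion turns a silent type mismatch between the outer inputs and the inner graph's inputs into an immediate failure at `make_node` time. A toy illustration of the kind of mismatch it catches (hypothetical variables, not from this PR):

import pytensor.tensor as pt

inner_input = pt.dvector("x")   # inner graph expects float64
outer_input = pt.fvector("x0")  # caller supplies float32
# Type equality compares dtype and shape information; here it is False,
# so make_node would raise the new AssertionError instead of mis-wiring the graph.
print(outer_input.type == inner_input.type)  # False
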
@@ -192,16 +195,17 @@ def __init__(

    def perform(self, node, inputs, outputs):
        f = self.fn_wrapped
+        f.clear_cache()
+
        x0, *args = inputs

        res = scipy_minimize_scalar(
-            fun=f,
+            fun=f.value,
            args=tuple(args),
            method=self.method,
            **self.optimizer_kwargs,
        )

-        f.clear_cache()
        outputs[0][0] = np.array(res.x)
        outputs[1][0] = np.bool_(res.success)

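
For context, `scipy.optimize.minimize_scalar` only consumes function values (it takes no gradient callback), which is why `fun=f.value` is passed rather than the full value-and-grad wrapper, and the cache is now cleared before the solve rather than after it. A standalone toy call mirroring the `fun=...`, `args=...`, `method=...` pattern above, with a made-up objective:

import numpy as np
from scipy.optimize import minimize_scalar

# Objective parameterized by `a`; scipy passes args after x.
res = minimize_scalar(fun=lambda x, a: (x - a) ** 2, args=(3.0,), method="brent")
print(np.array(res.x), np.bool_(res.success))  # ~3.0 True
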
@@ -214,11 +218,9 @@ def L_op(self, inputs, outputs, output_grads):
        inner_fx = self.fgraph.outputs[0]

        implicit_f = grad(inner_fx, inner_x)
-        df_dx = grad(implicit_f, inner_x)
-
-        df_dthetas = [
-            grad(implicit_f, arg, disconnected_inputs="ignore") for arg in inner_args
-        ]
+        df_dx, *df_dthetas = grad(
+            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+        )

        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
        df_dx_star, *df_dthetas_stars = graph_replace(
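
The rewritten block computes the same derivatives as before in a single `grad` call: `implicit_f` is df/dx, and its derivatives with respect to x and each theta are exactly what the implicit function theorem needs, since df/dx(x*(theta), theta) = 0 at the optimum implies dx*/dtheta = -(d²f/dx dtheta) / (d²f/dx²) in the scalar case. A tiny numeric check of that relation on a quadratic (illustrative only):

# f(x, theta) = (x - theta)**2  ->  x*(theta) = theta, so dx*/dtheta should be 1.
d2f_dx2 = 2.0         # second derivative of f in x
d2f_dx_dtheta = -2.0  # mixed derivative of f in x and theta
print(-d2f_dx_dtheta / d2f_dx2)  # 1.0
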