@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Sequence
 from copy import copy
 from typing import cast
@@ -8,7 +9,7 @@
 from scipy.optimize import root as scipy_root

 from pytensor import Variable, function, graph_replace
-from pytensor.gradient import grad, jacobian
+from pytensor.gradient import grad, hessian, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import graph_inputs, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -20,6 +21,87 @@
 from pytensor.tensor.variable import TensorVariable


+_log = logging.getLogger(__name__)
+
+
+class LRUCache1:
+    """
+    Simple LRU cache with a memory size of 1.
+
+    This cache is only usable for a function that takes a single input `x` and returns a single output. The
+    function can also take any number of additional arguments `*args`, but these are assumed to be constant
+    between function calls.
+
+    The purpose of this cache is to allow for Hessian computation to be reused when calling scipy.optimize functions.
+    It is very often the case that some sub-computations are repeated between the objective, gradient, and hessian
+    functions, but by default scipy only allows for the objective and gradient to be fused.
+
+    By using this cache, all 3 functions can be fused, which can significantly speed up the optimization process for
+    expensive functions.
+    """
+
+    def __init__(self, fn):
+        self.fn = fn
+        self.last_x = None
+        self.last_result = None
+
+        self.cache_hits = 0
+        self.cache_misses = 0
+
+        self.value_and_grad_calls = 0
+        self.hess_calls = 0
+
+    def __call__(self, x, *args):
+        """
+        Call the cached function with the given input `x` and additional arguments `*args`.
+
+        If the input `x` is the same as the last input, return the cached result. Otherwise update the cache with the
+        new input and result.
+        """
+        cache_hit = np.all(x == self.last_x)
+
+        if self.last_x is None or not cache_hit:
+            self.cache_misses += 1
+            result = self.fn(x, *args)
+            self.last_x = x
+            self.last_result = result
+            return result
+
+        else:
+            self.cache_hits += 1
+            return self.last_result
+
+    def value(self, x, *args):
+        self.value_and_grad_calls += 1
+        res = self(x, *args)
+        if isinstance(res, tuple):
+            return res[0]
+        else:
+            return res
+
+    def value_and_grad(self, x, *args):
+        self.value_and_grad_calls += 1
+        return self(x, *args)[:2]
+
+    def hess(self, x, *args):
+        self.hess_calls += 1
+        return self(x, *args)[-1]
+
+    def report(self):
+        _log.info(f"Value and Grad calls: {self.value_and_grad_calls}")
+        _log.info(f"Hess Calls: {self.hess_calls}")
+        _log.info(f"Hits: {self.cache_hits}")
+        _log.info(f"Misses: {self.cache_misses}")
+
+    def clear_cache(self):
+        self.last_x = None
+        self.last_result = None
+        self.cache_hits = 0
+        self.cache_misses = 0
+        self.value_and_grad_calls = 0
+        self.hess_calls = 0
+
+
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

@@ -44,9 +126,9 @@ def build_fn(self):
             def fn_wrapper(x, *args):
                 return fn(x.squeeze(), *args)

-            self._fn_wrapped = fn_wrapper
+            self._fn_wrapped = LRUCache1(fn_wrapper)
         else:
-            self._fn_wrapped = fn
+            self._fn_wrapped = LRUCache1(fn)

     @property
     def fn(self):
@@ -120,6 +202,7 @@ def perform(self, node, inputs, outputs):
             **self.optimizer_kwargs,
         )

+        f.clear_cache()
         outputs[0][0] = np.array(res.x)
         outputs[1][0] = np.bool_(res.success)

@@ -211,6 +294,12 @@ def __init__(
             )
             self.fgraph.add_output(grad_wrt_x)

+        if hess:
+            hess_wrt_x = cast(
+                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            )
+            self.fgraph.add_output(hess_wrt_x)
+
         self.jac = jac
         self.hess = hess
         self.hessp = hessp
@@ -225,14 +314,17 @@ def perform(self, node, inputs, outputs):
         x0, *args = inputs

         res = scipy_minimize(
-            fun=f,
+            fun=f.value_and_grad if self.jac else f.value,
             jac=self.jac,
             x0=x0,
             args=tuple(args),
+            hess=f.hess if self.hess else None,
             method=self.method,
             **self.optimizer_kwargs,
         )

+        f.clear_cache()
+
         outputs[0][0] = res.x
         outputs[1][0] = np.bool_(res.success)

@@ -283,6 +375,7 @@ def minimize(
     x: TensorVariable,
     method: str = "BFGS",
     jac: bool = True,
+    hess: bool = False,
     optimizer_kwargs: dict | None = None,
 ):
     """
@@ -325,6 +418,7 @@ def minimize(
         objective=objective,
         method=method,
         jac=jac,
+        hess=hess,
         optimizer_kwargs=optimizer_kwargs,
     )

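Note on how the cache is meant to be used (not part of the diff): the sketch below wraps a toy fused objective in `LRUCache1` and hands its `value_and_grad` and `hess` methods to `scipy.optimize.minimize`, mirroring what the modified `perform` does in the hunk above. It is illustrative only: the matrix `A`, the `value_grad_hess` function, and the choice of `trust-ncg` are made up, and `LRUCache1` is assumed to be in scope (imported from wherever this module lives); only its API comes from the diff.

```python
import numpy as np
from scipy.optimize import minimize as scipy_minimize

# Toy objective where value, gradient, and Hessian all share the expensive
# intermediate exp(A @ x); computing it once per iterate is the point of the cache.
rng = np.random.default_rng(0)
A = rng.normal(size=(50, 10)) * 0.1


def value_grad_hess(x):
    e = np.exp(A @ x)                       # shared sub-computation
    value = e.sum() + 0.5 * x @ x
    grad = A.T @ e + x
    hess = (A.T * e) @ A + np.eye(x.size)   # A.T @ diag(e) @ A + I
    return value, grad, hess


f = LRUCache1(value_grad_hess)              # class from the diff, assumed importable

res = scipy_minimize(
    fun=f.value_and_grad,   # returns (value, grad) from the cached triple
    jac=True,               # tells scipy that `fun` also returns the gradient
    hess=f.hess,            # served from the cache, no recomputation of exp(A @ x)
    x0=np.zeros(10),
    method="trust-ncg",
)
f.report()                  # logs hit/miss counts if logging is configured at INFO
f.clear_cache()
```

Because scipy calls the Hessian callback with the same `x` it just evaluated the objective and gradient at, that call is a cache hit, so the shared intermediate is computed once per iterate rather than twice.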
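For the user-facing side, a hypothetical end-to-end call of the updated `minimize` wrapper with the new `hess=True` flag. Several details are assumptions because they are not visible in these hunks: the module path `pytensor.tensor.optimize`, the first positional parameter being the scalar `objective` graph, the wrapper returning the solution variable together with a success flag (mirroring `res.x` and `res.success` in the Op's `perform`), and extra variables in the objective (here `a`) being picked up automatically as parameters, as the `graph_inputs`/`truncated_graph_inputs` imports suggest.

```python
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import minimize  # assumed module path

x = pt.vector("x")
a = pt.vector("a")                              # treated as a fixed parameter
objective = pt.exp(a * x).sum() + (x**2).sum()  # scalar graph in x

# Newton-type scipy methods can consume the symbolic Hessian added when hess=True.
x_star, success = minimize(objective, x, method="trust-ncg", jac=True, hess=True)

fn = function([x, a], [x_star, success])
print(fn(np.zeros(3), np.array([1.0, 2.0, 3.0])))
```

With `hess=True`, the Hessian graph is appended as an extra output of the inner `FunctionGraph`, so the objective, gradient, and Hessian come out of a single compiled function and the `LRUCache1` wrapper keeps scipy from re-evaluating it per callback.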