
Commit 4b52b9e

Use LRU cache wrapper for hessian option
1 parent c233ce4 commit 4b52b9e

2 files changed (+112, -6 lines)

pytensor/tensor/optimize.py

Lines changed: 98 additions & 4 deletions
@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Sequence
 from copy import copy
 from typing import cast
@@ -8,7 +9,7 @@
 from scipy.optimize import root as scipy_root
 
 from pytensor import Variable, function, graph_replace
-from pytensor.gradient import grad, jacobian
+from pytensor.gradient import grad, hessian, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import graph_inputs, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -20,6 +21,87 @@
 from pytensor.tensor.variable import TensorVariable
 
 
+_log = logging.getLogger(__name__)
+
+
+class LRUCache1:
+    """
+    Simple LRU cache with a memory size of 1.
+
+    This cache is only usable for a function that takes a single input `x` and returns a single output. The
+    function can also take any number of additional arguments `*args`, but these are assumed to be constant
+    between function calls.
+
+    The purpose of this cache is to allow for Hessian computation to be reused when calling scipy.optimize functions.
+    It is very often the case that some sub-computations are repeated between the objective, gradient, and hessian
+    functions, but by default scipy only allows for the objective and gradient to be fused.
+
+    By using this cache, all 3 functions can be fused, which can significantly speed up the optimization process for
+    expensive functions.
+    """
+
+    def __init__(self, fn):
+        self.fn = fn
+        self.last_x = None
+        self.last_result = None
+
+        self.cache_hits = 0
+        self.cache_misses = 0
+
+        self.value_and_grad_calls = 0
+        self.hess_calls = 0
+
+    def __call__(self, x, *args):
+        """
+        Call the cached function with the given input `x` and additional arguments `*args`.
+
+        If the input `x` is the same as the last input, return the cached result. Otherwise update the cache with the
+        new input and result.
+        """
+        cache_hit = np.all(x == self.last_x)
+
+        if self.last_x is None or not cache_hit:
+            self.cache_misses += 1
+            result = self.fn(x, *args)
+            self.last_x = x
+            self.last_result = result
+            return result
+
+        else:
+            self.cache_hits += 1
+            return self.last_result
+
+    def value(self, x, *args):
+        self.value_and_grad_calls += 1
+        res = self(x, *args)
+        if isinstance(res, tuple):
+            return res[0]
+        else:
+            return res
+
+    def value_and_grad(self, x, *args):
+        self.value_and_grad_calls += 1
+        return self(x, *args)[:2]
+
+    def hess(self, x, *args):
+        self.hess_calls += 1
+        return self(x, *args)[-1]
+
+    def report(self):
+        _log.info(f"Value and Grad calls: {self.value_and_grad_calls}")
+        _log.info(f"Hess Calls: {self.hess_calls}")
+        _log.info(f"Hits: {self.cache_hits}")
+        _log.info(f"Misses: {self.cache_misses}")
+
+    def clear_cache(self):
+        self.last_x = None
+        self.last_result = None
+        self.cache_hits = 0
+        self.cache_misses = 0
+        self.value_and_grad_calls = 0
+        self.hess_calls = 0
+
+
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""
 
@@ -44,9 +126,9 @@ def build_fn(self):
             def fn_wrapper(x, *args):
                 return fn(x.squeeze(), *args)
 
-            self._fn_wrapped = fn_wrapper
+            self._fn_wrapped = LRUCache1(fn_wrapper)
         else:
-            self._fn_wrapped = fn
+            self._fn_wrapped = LRUCache1(fn)
 
     @property
     def fn(self):
@@ -120,6 +202,7 @@ def perform(self, node, inputs, outputs):
             **self.optimizer_kwargs,
         )
 
+        f.clear_cache()
        outputs[0][0] = np.array(res.x)
        outputs[1][0] = np.bool_(res.success)
 
@@ -211,6 +294,12 @@ def __init__(
             )
             self.fgraph.add_output(grad_wrt_x)
 
+        if hess:
+            hess_wrt_x = cast(
+                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            )
+            self.fgraph.add_output(hess_wrt_x)
+
         self.jac = jac
         self.hess = hess
         self.hessp = hessp
@@ -225,14 +314,17 @@ def perform(self, node, inputs, outputs):
        x0, *args = inputs
 
        res = scipy_minimize(
-            fun=f,
+            fun=f.value_and_grad if self.jac else f.value,
            jac=self.jac,
            x0=x0,
            args=tuple(args),
+            hess=f.hess if self.hess else None,
            method=self.method,
            **self.optimizer_kwargs,
        )
 
+        f.clear_cache()
+
        outputs[0][0] = res.x
        outputs[1][0] = np.bool_(res.success)
 
@@ -283,6 +375,7 @@ def minimize(
    x: TensorVariable,
    method: str = "BFGS",
    jac: bool = True,
+    hess: bool = False,
    optimizer_kwargs: dict | None = None,
):
    """
@@ -325,6 +418,7 @@
        objective=objective,
        method=method,
        jac=jac,
+        hess=hess,
        optimizer_kwargs=optimizer_kwargs,
    )

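The idea behind LRUCache1 can be illustrated outside of PyTensor. The sketch below is a standalone, simplified version of the same technique, not the committed class; names like Cache1, rosen_value_grad_hess, and the evaluation counter are illustrative only. A single function computes value, gradient, and Hessian in one pass, a memory-of-one cache memoizes the last call, and scipy.optimize.minimize is pointed at thin views into the cached tuple, so the Newton-CG objective/gradient and Hessian callbacks trigger only one underlying evaluation per new x.

import numpy as np
from scipy.optimize import minimize


def rosen_value_grad_hess(x):
    """Value, gradient, and Hessian of the 2-D Rosenbrock function, computed in one pass."""
    a, b = 1.0, 100.0
    val = (a - x[0]) ** 2 + b * (x[1] - x[0] ** 2) ** 2
    grad = np.array(
        [
            -2 * (a - x[0]) - 4 * b * x[0] * (x[1] - x[0] ** 2),
            2 * b * (x[1] - x[0] ** 2),
        ]
    )
    hess = np.array(
        [
            [2 - 4 * b * (x[1] - 3 * x[0] ** 2), -4 * b * x[0]],
            [-4 * b * x[0], 2 * b],
        ]
    )
    return val, grad, hess


class Cache1:
    """Memory-of-one cache: re-evaluates only when x changes, mirroring LRUCache1 above."""

    def __init__(self, fn):
        self.fn = fn
        self.last_x = None
        self.last_result = None
        self.evaluations = 0

    def __call__(self, x):
        if self.last_x is None or not np.array_equal(x, self.last_x):
            self.evaluations += 1
            self.last_result = self.fn(x)
            self.last_x = np.copy(x)
        return self.last_result

    def value_and_grad(self, x):
        return self(x)[:2]

    def hess(self, x):
        return self(x)[-1]


cached = Cache1(rosen_value_grad_hess)
res = minimize(
    fun=cached.value_and_grad,  # returns (value, grad); jac=True tells scipy to unpack it
    x0=np.array([-1.2, 1.0]),
    jac=True,
    hess=cached.hess,  # called at the same x as fun, so it is served from the cache
    method="Newton-CG",
)
print("solution:", res.x, "underlying evaluations:", cached.evaluations)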
tests/tensor/test_optimize.py

Lines changed: 14 additions & 2 deletions
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest
 
 import pytensor
 import pytensor.tensor as pt
@@ -68,7 +69,16 @@ def f(x, a, b):
     utt.verify_grad(f, [0.0, a_val, c_val], eps=1e-6)
 
 
-def test_minimize_vector_x():
+@pytest.mark.parametrize(
+    "method, jac, hess",
+    [
+        ("Newton-CG", True, True),
+        ("L-BFGS-B", True, False),
+        ("powell", False, False),
+    ],
+    ids=["Newton-CG", "L-BFGS-B", "powell"],
+)
+def test_minimize_vector_x(method, jac, hess):
     def rosenbrock_shifted_scaled(x, a, b):
         return (a * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum() + b
 
@@ -77,7 +87,9 @@ def rosenbrock_shifted_scaled(x, a, b):
     b = pt.scalar("b")
 
     objective = rosenbrock_shifted_scaled(x, a, b)
-    minimized_x, success = minimize(objective, x, method="BFGS")
+    minimized_x, success = minimize(
+        objective, x, method=method, jac=jac, hess=hess, optimizer_kwargs={"tol": 1e-16}
+    )
 
     a_val = 0.5
     b_val = 1.0

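From the user-facing side, a minimal sketch of what the new hess option enables, mirroring the parametrized test above. The import path for minimize follows the module changed in this commit; the starting point and the exact inputs to the compiled function are assumptions, not part of the test.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
a = pt.scalar("a")
b = pt.scalar("b")

# Shifted/scaled Rosenbrock objective, as in the test above
objective = (a * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum() + b

# hess=True adds the symbolic Hessian as an extra output of the inner graph;
# at run time the LRU cache lets Newton-CG reuse it alongside the value and gradient
minimized_x, success = minimize(objective, x, method="Newton-CG", jac=True, hess=True)

fn = pytensor.function([x, a, b], [minimized_x, success])
x_star, converged = fn(np.zeros(5), 0.5, 1.0)  # assumed starting point and test values a=0.5, b=1.0
print(x_star, converged)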