Commit 9807a18

1 parent a47cbfc commit 9807a18

pytensor/tensor/optimize.py

Lines changed: 25 additions & 23 deletions
@@ -47,6 +47,8 @@ def __init__(self, fn):
         self.cache_hits = 0
         self.cache_misses = 0
 
+        self.value_calls = 0
+        self.grad_calls = 0
         self.value_and_grad_calls = 0
         self.hess_calls = 0
 
@@ -57,26 +59,27 @@ def __call__(self, x, *args):
         If the input `x` is the same as the last input, return the cached result. Otherwise update the cache with the
         new input and result.
         """
-        cache_hit = np.all(x == self.last_x)
 
-        if self.last_x is None or not cache_hit:
+        if self.last_result is None or not (x == self.last_x).all():
             self.cache_misses += 1
-            result = self.fn(x, *args)
             self.last_x = x
+
+            result = self.fn(x, *args)
             self.last_result = result
+
             return result
 
         else:
             self.cache_hits += 1
             return self.last_result
 
     def value(self, x, *args):
-        self.value_and_grad_calls += 1
-        res = self(x, *args)
-        if isinstance(res, tuple):
-            return res[0]
-        else:
-            return res
+        self.value_calls += 1
+        return self(x, *args)[0]
+
+    def grad(self, x, *args):
+        self.grad_calls += 1
+        return self(x, *args)[1]
 
     def value_and_grad(self, x, *args):
         self.value_and_grad_calls += 1
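
For readers skimming the diff, here is a minimal, self-contained sketch of the caching pattern these hunks refine (the `ValueGradCache` class below is illustrative, not the actual class in optimize.py): the wrapped function returns a `(value, grad)` pair, a single-entry cache keyed on the last `x` avoids recomputing when scipy asks for the value and the gradient at the same point, and the new `value`/`grad` methods simply index into the cached pair while the new counters record how often each was requested.

```python
import numpy as np

class ValueGradCache:
    """Single-entry cache for a function returning (value, grad); illustrative only."""

    def __init__(self, fn):
        self.fn = fn
        self.last_x = None
        self.last_result = None
        self.cache_hits = 0
        self.cache_misses = 0
        self.value_calls = 0
        self.grad_calls = 0

    def __call__(self, x, *args):
        # Recompute only when x differs from the last input seen.
        if self.last_result is None or not np.array_equal(x, self.last_x):
            self.cache_misses += 1
            self.last_x = x
            self.last_result = self.fn(x, *args)
        else:
            self.cache_hits += 1
        return self.last_result

    def value(self, x, *args):
        self.value_calls += 1
        return self(x, *args)[0]

    def grad(self, x, *args):
        self.grad_calls += 1
        return self(x, *args)[1]


# scipy-style callers ask for value and grad at the same x; fn runs only once.
cached = ValueGradCache(lambda x: ((x**2).sum(), 2 * x))
x0 = np.array([1.0, 2.0])
print(cached.value(x0), cached.grad(x0), cached.cache_hits)  # 5.0 [2. 4.] 1
```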
@@ -97,6 +100,8 @@ def clear_cache(self):
         self.last_result = None
         self.cache_hits = 0
         self.cache_misses = 0
+        self.value_calls = 0
+        self.grad_calls = 0
         self.value_and_grad_calls = 0
         self.hess_calls = 0
 

@@ -109,14 +114,8 @@ def build_fn(self):
         This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
         wrapper function logic is there to handle this.
         """
-        # TODO: Introduce rewrites to change MinimizeOp to MinimizeScalarOp and RootOp to RootScalarOp
-        # when x is scalar. That will remove the need for the wrapper.
-
         outputs = self.inner_outputs
-        if len(outputs) == 1:
-            outputs = outputs[0]
-        self._fn = fn = function(self.inner_inputs, outputs)
-
+        self._fn = fn = function(self.inner_inputs, outputs, trust_input=True)
         # Do this reassignment to see the compiled graph in the dprint
         # self.fgraph = fn.maker.fgraph
 

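Dropping the scalar-unwrapping branch and compiling with `trust_input=True` means the inner function no longer validates or converts its inputs on every call, which is the point when scipy evaluates it in a tight loop. A small standalone sketch of the same flag on an assumed toy graph:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
objective = (x**2).sum()

# trust_input=True: the caller guarantees inputs already have the expected
# dtype/shape, so the compiled function skips per-call validation and conversion.
fn = pytensor.function([x], objective, trust_input=True)

print(fn(np.array([1.0, 2.0, 3.0])))  # 14.0
```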
@@ -166,6 +165,10 @@ def prepare_node(
 
     def make_node(self, *inputs):
         assert len(inputs) == len(self.inner_inputs)
+        for input, inner_input in zip(inputs, self.inner_inputs):
+            assert (
+                input.type == inner_input.type
+            ), f"Input {input} does not match expected type {inner_input.type}"
 
         return Apply(
             self, inputs, [self.inner_inputs[0].type(), scalar_bool("success")]
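
The new per-input assertion in `make_node` turns a type mismatch into an immediate, readable failure at graph-construction time instead of a confusing error inside the compiled inner function. Roughly what the check compares, using assumed example types:

```python
import pytensor.tensor as pt

inner_x = pt.dscalar("x")   # type the inner graph was built with
good_x = pt.dscalar("x0")   # same dtype and ndim -> passes the new assertion
bad_x = pt.fvector("x0")    # float32 vector vs. float64 scalar -> fails early

print(good_x.type == inner_x.type)  # True
print(bad_x.type == inner_x.type)   # False, now caught in make_node with a clear message
```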
@@ -192,16 +195,17 @@ def __init__(
 
     def perform(self, node, inputs, outputs):
         f = self.fn_wrapped
+        f.clear_cache()
+
         x0, *args = inputs
 
         res = scipy_minimize_scalar(
-            fun=f,
+            fun=f.value,
             args=tuple(args),
             method=self.method,
             **self.optimizer_kwargs,
         )
 
-        f.clear_cache()
         outputs[0][0] = np.array(res.x)
         outputs[1][0] = np.bool_(res.success)
 

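Since the wrapped callable now returns a `(value, grad)` pair, `perform` passes `f.value` so `scipy.optimize.minimize_scalar` sees only the scalar objective, and the cache is cleared up front so stale state and counters do not leak between calls to `perform`. A hedged, self-contained sketch of that calling pattern with a made-up objective:

```python
from scipy.optimize import minimize_scalar

def value_and_grad(x, a):
    # Inner function computes the objective and its gradient in one pass.
    return (x - a) ** 2, 2.0 * (x - a)

# minimize_scalar only wants the scalar value, so hand it a value-only view,
# mirroring fun=f.value in the hunk above.
res = minimize_scalar(lambda x, a: value_and_grad(x, a)[0], args=(3.0,), method="brent")
print(res.x, res.success)  # ~3.0 True
```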
@@ -214,11 +218,9 @@ def L_op(self, inputs, outputs, output_grads):
         inner_fx = self.fgraph.outputs[0]
 
         implicit_f = grad(inner_fx, inner_x)
-        df_dx = grad(implicit_f, inner_x)
-
-        df_dthetas = [
-            grad(implicit_f, arg, disconnected_inputs="ignore") for arg in inner_args
-        ]
+        df_dx, *df_dthetas = grad(
+            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+        )
 
         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
         df_dx_star, *df_dthetas_stars = graph_replace(
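
The `L_op` change computes the derivative of the first-order condition with respect to `x` and all the `theta` arguments in a single `grad` call over a list of variables, so shared intermediate results are built once; the implicit-function-theorem gradient dx*/dtheta = -(df'/dtheta) / (df'/dx) is then assembled from the pieces. A small sketch on an assumed one-dimensional objective:

```python
import pytensor.tensor as pt
from pytensor.gradient import grad

x = pt.dscalar("x")
theta = pt.dscalar("theta")
fx = (x - theta) ** 2            # stand-in for the inner objective

implicit_f = grad(fx, x)         # first-order condition df/dx

# One grad call over a list of variables returns one gradient per entry and
# lets shared subgraphs be reused, instead of one grad call per variable.
df_dx, df_dtheta = grad(implicit_f, [x, theta], disconnected_inputs="ignore")

# Implicit function theorem: dx*/dtheta = -(df'/dtheta) / (df'/dx)
dxstar_dtheta = -df_dtheta / df_dx
print(dxstar_dtheta.eval({x: 5.0, theta: 3.0}))  # 1.0 (since x* = theta)
```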
