Skip to content

Commit 2f0119c

Browse files
minimize works
1 parent e42f0fd commit 2f0119c

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

pytensor/tensor/optimize.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
from copy import copy
33
from typing import cast
44

5+
import numpy as np
56
from scipy.optimize import minimize as scipy_minimize
67
from scipy.optimize import root as scipy_root
78

89
from pytensor import Variable, function, graph_replace
9-
from pytensor.gradient import DisconnectedType, grad, jacobian
10+
from pytensor.gradient import grad, jacobian
1011
from pytensor.graph import Apply, Constant, FunctionGraph
11-
from pytensor.graph.basic import truncated_graph_inputs
12+
from pytensor.graph.basic import graph_inputs, truncated_graph_inputs
1213
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
1314
from pytensor.scalar import bool as scalar_bool
15+
from pytensor.tensor import dot
1416
from pytensor.tensor.basic import atleast_2d, concatenate, zeros_like
17+
from pytensor.tensor.blockwise import Blockwise
1518
from pytensor.tensor.slinalg import solve
1619
from pytensor.tensor.variable import TensorVariable
1720

@@ -33,7 +36,7 @@ def build_fn(self):
3336
self._fn = fn = function(self.inner_inputs, outputs)
3437

3538
# Do this reassignment to see the compiled graph in the dprint
36-
self.fgraph = fn.maker.fgraph
39+
# self.fgraph = fn.maker.fgraph
3740

3841
if self.inner_inputs[0].type.shape == ():
3942

@@ -128,11 +131,11 @@ def perform(self, node, inputs, outputs):
128131
x0=x0,
129132
args=tuple(args),
130133
method=self.method,
131-
**self.options,
134+
**self.optimizer_kwargs,
132135
)
133136

134137
outputs[0][0] = res.x
135-
outputs[1][0] = res.success
138+
outputs[1][0] = np.bool_(res.success)
136139

137140
def L_op(self, inputs, outputs, output_grads):
138141
x, *args = inputs
@@ -158,26 +161,22 @@ def L_op(self, inputs, outputs, output_grads):
158161

159162
df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
160163

161-
grad_wrt_args_vector = solve(-df_dtheta_star, df_dx_star)
164+
grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
162165

163166
cursor = 0
164167
grad_wrt_args = []
165168

166-
for output_grad, arg in zip(output_grads, args, strict=True):
169+
for arg in args:
167170
arg_shape = arg.shape
168171
arg_size = arg_shape.prod()
169-
arg_grad = grad_wrt_args_vector[cursor : cursor + arg_size].reshape(
170-
arg_shape
172+
arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
173+
(*x_star.shape, *arg_shape)
171174
)
172175

173-
grad_wrt_args.append(
174-
arg_grad * output_grad
175-
if not isinstance(output_grad.type, DisconnectedType)
176-
else DisconnectedType()
177-
)
176+
grad_wrt_args.append(dot(output_grad, arg_grad))
178177
cursor += arg_size
179178

180-
return [x.zeros_like(), *grad_wrt_args]
179+
return [zeros_like(x), *grad_wrt_args]
181180

182181

183182
def minimize(
@@ -217,7 +216,7 @@ def minimize(
217216
"""
218217
args = [
219218
arg
220-
for arg in truncated_graph_inputs([objective], [x])
219+
for arg in graph_inputs([objective], [x])
221220
if (arg is not x and not isinstance(arg, Constant))
222221
]
223222

@@ -230,7 +229,18 @@ def minimize(
230229
optimizer_kwargs=optimizer_kwargs,
231230
)
232231

233-
return minimize_op(x, *args)
232+
input_core_ndim = [var.ndim for var in minimize_op.inner_inputs]
233+
input_signatures = [
234+
f'({",".join(f"i{i}{n}" for n in range(ndim))})'
235+
for i, ndim in enumerate(input_core_ndim)
236+
]
237+
238+
# Output dimensions are always the same as the first input (the initial values for the optimizer),
239+
# then a scalar for the success flag
240+
output_signatures = [input_signatures[0], "()"]
241+
242+
signature = f"{','.join(input_signatures)}->{','.join(output_signatures)}"
243+
return Blockwise(minimize_op, signature=signature)(x, *args)
234244

235245

236246
class RootOp(ScipyWrapperOp):

tests/tensor/test_optimize.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytensor
44
import pytensor.tensor as pt
5-
from pytensor import config
5+
from pytensor import config, function
66
from pytensor.tensor.optimize import minimize, root
77
from tests import unittest_tools as utt
88

@@ -24,8 +24,12 @@ def test_simple_minimize():
2424
a_val = 2.0
2525
c_val = 3.0
2626

27-
assert success
28-
assert minimized_x.eval({a: a_val, c: c_val, x: 0.0}) == (2 * a_val * c_val)
27+
f = function([a, c, x], [minimized_x, success])
28+
29+
minimized_x_val, success_val = f(a_val, c_val, 0.0)
30+
31+
assert success_val
32+
assert minimized_x_val == (2 * a_val * c_val)
2933

3034
def f(x, a, b):
3135
objective = (x - a * b) ** 2
@@ -51,7 +55,8 @@ def rosenbrock_shifted_scaled(x, a, b):
5155
x0 = np.zeros(5).astype(floatX)
5256
x_star_val = minimized_x.eval({a: a_val, b: b_val, x: x0})
5357

54-
assert success
58+
assert success.eval({a: a_val, b: b_val, x: x0})
59+
5560
np.testing.assert_allclose(
5661
x_star_val, np.ones_like(x_star_val), atol=1e-6, rtol=1e-6
5762
)

0 commit comments

Comments
 (0)