from copy import copy
from typing import cast

+ import numpy as np
from scipy.optimize import minimize as scipy_minimize
from scipy.optimize import root as scipy_root

from pytensor import Variable, function, graph_replace
- from pytensor.gradient import DisconnectedType, grad, jacobian
+ from pytensor.gradient import grad, jacobian
from pytensor.graph import Apply, Constant, FunctionGraph
- from pytensor.graph.basic import truncated_graph_inputs
+ from pytensor.graph.basic import graph_inputs, truncated_graph_inputs
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
from pytensor.scalar import bool as scalar_bool
+ from pytensor.tensor import dot
from pytensor.tensor.basic import atleast_2d, concatenate, zeros_like
+ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import solve
from pytensor.tensor.variable import TensorVariable

@@ -33,7 +36,7 @@ def build_fn(self):
        self._fn = fn = function(self.inner_inputs, outputs)

        # Do this reassignment to see the compiled graph in the dprint
-        self.fgraph = fn.maker.fgraph
+        # self.fgraph = fn.maker.fgraph

        if self.inner_inputs[0].type.shape == ():
@@ -128,11 +131,11 @@ def perform(self, node, inputs, outputs):
            x0=x0,
            args=tuple(args),
            method=self.method,
-            **self.options,
+            **self.optimizer_kwargs,
        )

        outputs[0][0] = res.x
-        outputs[1][0] = res.success
+        outputs[1][0] = np.bool_(res.success)

    def L_op(self, inputs, outputs, output_grads):
        x, *args = inputs
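For reference (not part of the commit): `scipy.optimize.minimize` returns an `OptimizeResult` whose `.x` is the solution array and whose `.success` is a plain Python bool; casting with `np.bool_` presumably matches the boolean output type the Op declares (note the `scalar_bool` import above). A minimal standalone sketch:

```python
# Illustration only, not from the commit: what perform() receives from SciPy.
import numpy as np
from scipy.optimize import minimize as scipy_minimize

res = scipy_minimize(lambda v: ((v - 3.0) ** 2).sum(), x0=np.zeros(2), method="BFGS")
print(res.x)                  # ndarray stored in outputs[0][0]
print(np.bool_(res.success))  # NumPy bool stored in outputs[1][0]
```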
@@ -158,26 +161,22 @@ def L_op(self, inputs, outputs, output_grads):

        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)

-        grad_wrt_args_vector = solve(-df_dtheta_star, df_dx_star)
+        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)

        cursor = 0
        grad_wrt_args = []

-        for output_grad, arg in zip(output_grads, args, strict=True):
+        for arg in args:
            arg_shape = arg.shape
            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[cursor : cursor + arg_size].reshape(
-                arg_shape
+            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+                (*x_star.shape, *arg_shape)
            )

-            grad_wrt_args.append(
-                arg_grad * output_grad
-                if not isinstance(output_grad.type, DisconnectedType)
-                else DisconnectedType()
-            )
+            grad_wrt_args.append(dot(output_grad, arg_grad))

            cursor += arg_size

-        return [x.zeros_like(), *grad_wrt_args]
+        return [zeros_like(x), *grad_wrt_args]


def minimize(
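The reworked `L_op` follows the implicit function theorem: at the optimum the gradient of the objective vanishes, and differentiating that stationarity condition gives dx*/dθ = -(∂g/∂x)⁻¹ ∂g/∂θ, which is what the corrected argument order in `solve(-df_dx_star, df_dtheta_star)` computes before contracting with `output_grad`. Below is a NumPy check on a quadratic objective where the answer is known in closed form; this is my own illustration, and the variable names only loosely mirror those in the diff.

```python
# Illustration only: implicit-function-theorem gradient for
#   f(x, theta) = 0.5 * x @ A @ x - theta @ x,  so  x*(theta) = inv(A) @ theta.
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
A = A @ A.T + 3.0 * np.eye(3)        # symmetric positive definite Hessian
theta = rng.normal(size=3)

x_star = np.linalg.solve(A, theta)   # argmin_x f(x, theta)

# Stationarity: grad_x f(x*, theta) = A @ x* - theta = 0. Differentiating wrt theta:
#   (d grad_x f / dx) @ dx*/dtheta + (d grad_x f / dtheta) = 0
df_dx_star = A                       # Hessian of f wrt x at x*
df_dtheta_star = -np.eye(3)          # mixed derivative wrt theta
dxstar_dtheta = np.linalg.solve(-df_dx_star, df_dtheta_star)   # equals inv(A)

np.testing.assert_allclose(dxstar_dtheta, np.linalg.inv(A), atol=1e-10)

output_grad = rng.normal(size=3)     # cotangent from the rest of the graph
grad_wrt_theta = output_grad @ dxstar_dtheta   # analogue of dot(output_grad, arg_grad)
print(grad_wrt_theta)
```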
@@ -217,7 +216,7 @@ def minimize(
    """
    args = [
        arg
-        for arg in truncated_graph_inputs([objective], [x])
+        for arg in graph_inputs([objective], [x])
        if (arg is not x and not isinstance(arg, Constant))
    ]
@@ -230,7 +229,18 @@ def minimize(
        optimizer_kwargs=optimizer_kwargs,
    )

-    return minimize_op(x, *args)
+    input_core_ndim = [var.ndim for var in minimize_op.inner_inputs]
+    input_signatures = [
+        f'({",".join(f"i{i}{n}" for n in range(ndim))})'
+        for i, ndim in enumerate(input_core_ndim)
+    ]
+
+    # Output dimensions are always the same as the first input (the initial values for the optimizer),
+    # then a scalar for the success flag
+    output_signatures = [input_signatures[0], "()"]
+
+    signature = f"{','.join(input_signatures)}->{','.join(output_signatures)}"
+    return Blockwise(minimize_op, signature=signature)(x, *args)


class RootOp(ScipyWrapperOp):
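To make the new return path concrete, here is the signature-building logic rerun standalone for an assumed case of a vector initial point and one matrix parameter (the core ndims are illustrative, not taken from the commit). Wrapping the core Op in `Blockwise` with this gufunc-style signature is what lets the node broadcast over extra leading batch dimensions, with the first output shaped like the initial point and a scalar success flag as the second output.

```python
# Illustration only: the gufunc signature produced for assumed core ndims [1, 2]
# (a vector x0 and one matrix parameter).
input_core_ndim = [1, 2]
input_signatures = [
    f'({",".join(f"i{i}{n}" for n in range(ndim))})'
    for i, ndim in enumerate(input_core_ndim)
]
output_signatures = [input_signatures[0], "()"]   # solution like x0, scalar success flag
signature = f"{','.join(input_signatures)}->{','.join(output_signatures)}"
print(signature)  # (i00),(i10,i11)->(i00),()
```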