
Commit f801c96

Factor out shared functions
1 parent c587e33 commit f801c96

File tree

1 file changed: +40 −37 lines


pytensor/tensor/optimize.py

Lines changed: 40 additions & 37 deletions
@@ -128,6 +128,32 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
     ]
 
 
+def _get_parameter_grads_from_vector(
+    grad_wrt_args_vector: Variable,
+    x_star: Variable,
+    args: Sequence[Variable],
+    output_grad: Variable,
+):
+    """
+    Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
+    returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
+    """
+    cursor = 0
+    grad_wrt_args = []
+
+    for arg in args:
+        arg_shape = arg.shape
+        arg_size = arg_shape.prod()
+        arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+            (*x_star.shape, *arg_shape)
+        )
+
+        grad_wrt_args.append(dot(output_grad, arg_grad))
+        cursor += arg_size
+
+    return grad_wrt_args
+
+
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""
 
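Note: the new helper walks the concatenated Jacobian block by block. For each parameter it slices out that parameter's columns, unravels them back to the parameter's original shape, and contracts with the output gradient (the chain rule). A minimal NumPy sketch of the same slice-reshape-contract pattern, with made-up shapes and np.tensordot standing in for the symbolic dot (illustration only, not part of this commit):

import numpy as np

rng = np.random.default_rng(0)
x_star = rng.normal(size=2)  # stand-in for the optimum, shape (2,)
arg_shapes = [(2, 3), (4,)]  # two hypothetical parameters
n_cols = sum(int(np.prod(s)) for s in arg_shapes)

grad_wrt_args_vector = rng.normal(size=(2, n_cols))  # dx*/d(raveled params)
output_grad = rng.normal(size=2)                     # dL/dx*

cursor, grad_wrt_args = 0, []
for shape in arg_shapes:
    size = int(np.prod(shape))
    # Slice this parameter's columns, then unravel back to its shape.
    arg_grad = grad_wrt_args_vector[:, cursor : cursor + size].reshape(
        (*x_star.shape, *shape)
    )
    # Chain rule: contract dL/dx* with dx*/dparam over the x* axis.
    grad_wrt_args.append(np.tensordot(output_grad, arg_grad, axes=1))
    cursor += size

# Each gradient comes back with its parameter's unraveled shape.
assert [g.shape for g in grad_wrt_args] == arg_shapes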
@@ -348,34 +374,25 @@ def L_op(self, inputs, outputs, output_grads):
 
         implicit_f = grad(inner_fx, inner_x)
 
-        df_dx = atleast_2d(concatenate(jacobian(implicit_f, [inner_x]), axis=-1))
+        df_dx, *df_dtheta_columns = jacobian(
+            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+        )
 
         df_dtheta = concatenate(
-            [
-                atleast_2d(x, left=False)
-                for x in jacobian(implicit_f, inner_args, disconnected_inputs="ignore")
-            ],
+            [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
             axis=-1,
         )
 
         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
 
-        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
+        df_dx_star, df_dtheta_star = graph_replace(
+            [atleast_2d(df_dx), df_dtheta], replace=replace
+        )
 
         grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
+        grad_wrt_args = _get_parameter_grads_from_vector(
+            grad_wrt_args_vector, x_star, args, output_grad
+        )
 
         return [zeros_like(x), *grad_wrt_args]
 
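Note: the solve(-df_dx_star, df_dtheta_star) step is the implicit function theorem. The optimum satisfies implicit_f(x*(theta), theta) = 0, so differentiating gives dx*/dtheta = -(df/dx)^-1 (df/dtheta). A toy numeric check with assumed values (illustration only, not part of this commit):

import numpy as np

# Toy problem: minimize (x - theta)**2, so implicit_f(x, theta) = 2*(x - theta)
# vanishes at the optimum and x*(theta) = theta.
df_dx = np.array([[2.0]])       # d implicit_f / d x, evaluated at x*
df_dtheta = np.array([[-2.0]])  # d implicit_f / d theta, evaluated at x*

dxstar_dtheta = np.linalg.solve(-df_dx, df_dtheta)
assert np.allclose(dxstar_dtheta, 1.0)  # consistent with x*(theta) = theta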
@@ -504,19 +521,9 @@ def L_op(
         df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
 
         grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-
-        cursor = 0
-        grad_wrt_args = []
-
-        for arg in args:
-            arg_shape = arg.shape
-            arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-                (*x_star.shape, *arg_shape)
-            )
-
-            grad_wrt_args.append(dot(output_grad, arg_grad))
-            cursor += arg_size
+        grad_wrt_args = _get_parameter_grads_from_vector(
+            grad_wrt_args_vector, x_star, args, output_grad
+        )
 
         return [zeros_like(x), *grad_wrt_args]

@@ -529,11 +536,7 @@ def root(
 ):
     """Find roots of a system of equations using scipy.optimize.root."""
 
-    args = [
-        arg
-        for arg in truncated_graph_inputs([equations], [variables])
-        if (arg is not variables and not isinstance(arg, Constant))
-    ]
+    args = _find_optimization_parameters(equations, variables)
 
     root_op = RootOp(variables, *args, equations=equations, method=method, jac=jac)
 
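Note: the comprehension deleted here matches the context around the hunk at line 128, which suggests _find_optimization_parameters wraps the same filter: every input of the objective graph that is not the optimization variable itself and is not a constant. A sketch of its assumed body, inferred from this diff rather than shown in it:

def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
    """Find the optimization parameters: every non-constant input of the
    objective graph other than the optimization variable x itself."""
    return [
        arg
        for arg in truncated_graph_inputs([objective], [x])
        if (arg is not x and not isinstance(arg, Constant))
    ]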