 
 import numpy as np
 from scipy.optimize import minimize as scipy_minimize
+from scipy.optimize import minimize_scalar as scipy_minimize_scalar
 from scipy.optimize import root as scipy_root
 
 from pytensor import Variable, function, graph_replace
@@ -90,6 +91,104 @@ def make_node(self, *inputs):
         )
 
 
+class MinimizeScalarOp(ScipyWrapperOp):
+    __props__ = ("method",)
+
+    def __init__(
+        self,
+        x: Variable,
+        *args: Variable,
+        objective: Variable,
+        method: str = "brent",
+        optimizer_kwargs: dict | None = None,
+    ):
+        self.fgraph = FunctionGraph([x, *args], [objective])
+
+        self.method = method
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
+        self._fn = None
+        self._fn_wrapped = None
+
+    def perform(self, node, inputs, outputs):
+        f = self.fn_wrapped
+        x0, *args = inputs
+
+        # scipy.optimize.minimize_scalar takes no initial point; x0 only fixes
+        # the type of the output. A bracket or bounds can be forwarded through
+        # optimizer_kwargs.
+        res = scipy_minimize_scalar(
+            fun=f,
+            args=tuple(args),
+            method=self.method,
+            **self.optimizer_kwargs,
+        )
+
+        outputs[0][0] = np.array(res.x)
+        outputs[1][0] = np.bool_(res.success)
+
+    def L_op(self, inputs, outputs, output_grads):
+        x, *args = inputs
+        x_star, _ = outputs
+        output_grad, _ = output_grads
+
+        inner_x, *inner_args = self.fgraph.inputs
+        inner_fx = self.fgraph.outputs[0]
+
+        # At the minimum, the first derivative of the objective with respect
+        # to x is zero; this condition defines x*(theta) implicitly.
+        implicit_f = grad(inner_fx, inner_x)
+
+        # Second derivatives of the objective with respect to x and to each
+        # extra parameter theta
+        df_dx = grad(implicit_f, inner_x)
+        df_dthetas = [
+            grad(implicit_f, arg, disconnected_inputs="ignore") for arg in inner_args
+        ]
+
+        # Evaluate those derivatives at the optimum found in the forward pass
+        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+        df_dx_star, *df_dthetas_stars = graph_replace(
+            [df_dx, *df_dthetas], replace=replace
+        )
+
+        # Implicit function theorem: dx*/dtheta = -f_xtheta / f_xx, chained
+        # with the incoming output gradient
+        grad_wrt_args = [
+            (-df_dtheta_star / df_dx_star) * output_grad
+            for df_dtheta_star in df_dthetas_stars
+        ]
+
+        # The location of the minimum does not depend on the initial value,
+        # so the gradient with respect to x is zero
+        return [zeros_like(x), *grad_wrt_args]
+
+
+def minimize_scalar(
+    objective: TensorVariable,
+    x: TensorVariable,
+    method: str = "brent",
+    optimizer_kwargs: dict | None = None,
+):
+    """
+    Minimize a scalar objective function using scipy.optimize.minimize_scalar.
+
+    Parameters
+    ----------
+    objective : TensorVariable
+        Scalar symbolic expression to minimize; a function of ``x`` and of any
+        other variables in its graph.
+    x : TensorVariable
+        Scalar variable with respect to which ``objective`` is minimized.
+    method : str, default "brent"
+        Method passed to ``scipy.optimize.minimize_scalar``.
+    optimizer_kwargs : dict, optional
+        Extra keyword arguments passed to ``scipy.optimize.minimize_scalar``.
+
+    Returns
+    -------
+    x_star, success
+        The value of ``x`` that minimizes ``objective``, and a boolean flag
+        indicating whether the optimizer reported success.
+    """
+
+    # Every non-constant root variable of the objective graph other than x
+    # becomes an explicit input of the Op, so gradients can flow through it
+    args = [
+        arg
+        for arg in graph_inputs([objective], [x])
+        if (arg is not x and not isinstance(arg, Constant))
+    ]
+
+    minimize_scalar_op = MinimizeScalarOp(
+        x,
+        *args,
+        objective=objective,
+        method=method,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    # Build a gufunc signature so the Op can be Blockwise-d over batched
+    # inputs, e.g. "(),(i10)->(),()" for a scalar x and one vector argument
+    input_core_ndim = [var.ndim for var in minimize_scalar_op.inner_inputs]
+    input_signatures = [
+        f'({",".join(f"i{i}{n}" for n in range(ndim))})'
+        for i, ndim in enumerate(input_core_ndim)
+    ]
+
+    # Output dimensions are always the same as the first input (the initial
+    # values for the optimizer), then a scalar for the success flag
+    output_signatures = [input_signatures[0], "()"]
+
+    signature = f"{','.join(input_signatures)}->{','.join(output_signatures)}"
+    return Blockwise(minimize_scalar_op, signature=signature)(x, *args)
+
+
 class MinimizeOp(ScipyWrapperOp):
     __props__ = ("method", "jac", "hess", "hessp")
 
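The gradient rule in `L_op` above follows from the implicit function theorem. As a sketch of the reasoning (not part of the diff): write the objective as $f(x, \theta)$, where $\theta$ is any non-optimized input. The minimizer $x^*(\theta)$ satisfies the first-order condition, and differentiating that condition with respect to $\theta$ gives

$$f_x(x^*(\theta), \theta) = 0 \quad\Rightarrow\quad f_{xx}\,\frac{dx^*}{d\theta} + f_{x\theta} = 0 \quad\Rightarrow\quad \frac{dx^*}{d\theta} = -\frac{f_{x\theta}}{f_{xx}},$$

which is exactly `(-df_dtheta_star / df_dx_star)` in the code, with both second derivatives evaluated at $x^*$. The zero returned for the initial value reflects that the argmin does not depend on where the search starts.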
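For illustration, a minimal usage sketch of the new helper (the quadratic objective and the variable names are invented for this example, and it assumes `minimize_scalar` is importable from `pytensor.tensor.optimize`, the module this diff touches):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize_scalar

a = pt.scalar("a")
x = pt.scalar("x")

# Toy objective with known minimum at x* = a
objective = (x - a) ** 2

x_star, success = minimize_scalar(objective, x, method="brent")

# x is an input of the Op (it fixes the output type), but scipy chooses its
# own bracket, so the numeric value passed for x does not affect the result
fn = pytensor.function([x, a], [x_star, success])
x_opt, ok = fn(0.0, 3.0)  # expect x_opt close to 3.0 and ok == True

# Thanks to the implicit-function-theorem gradient, x_star is differentiable
# with respect to a; for this objective dx*/da = 1
dxda = pytensor.grad(x_star, a)

Because the forward pass ignores the value of `x`, only the gradients with respect to the other inputs (here `a`) are meaningful; the gradient with respect to `x` itself is zero by construction.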