 from scipy.optimize import minimize as scipy_minimize
 from scipy.optimize import minimize_scalar as scipy_minimize_scalar
 from scipy.optimize import root as scipy_root
+from scipy.optimize import root_scalar as scipy_root_scalar

 from pytensor import Variable, function, graph_replace
 from pytensor.gradient import grad, hessian, jacobian
@@ -529,8 +530,111 @@ def minimize(
     return minimize_op(x, *args)


+class RootScalarOp(ScipyWrapperOp):
+    __props__ = ("method", "jac", "hess")
+
+    def __init__(
+        self,
+        variables,
+        *args,
+        equation,
+        method,
+        jac: bool = False,
+        hess: bool = False,
+        optimizer_kwargs=None,
+    ):
+        self.fgraph = FunctionGraph([variables, *args], [equation])
+
+        if jac:
+            f_prime = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            self.fgraph.add_output(f_prime)
+
+        if hess:
+            if not jac:
+                raise ValueError(
+                    "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
+                    " using first derivatives."
+                )
+            f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
+            self.fgraph.add_output(f_double_prime)
+
+        self.method = method
+        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
+        self.jac = jac
+        self.hess = hess
+
+        self._fn = None
+        self._fn_wrapped = None
+
+    def perform(self, node, inputs, outputs):
+        f = self.fn_wrapped
+        f.clear_cache()
+        # f.copy_x = True
+
+        variables, *args = inputs
+
+        res = scipy_root_scalar(
+            f=f.value,
+            fprime=f.grad if self.jac else None,
+            fprime2=f.hess if self.hess else None,
+            x0=variables,
+            args=tuple(args),
+            method=self.method,
+            **self.optimizer_kwargs,
+        )
+
+        outputs[0][0] = np.array(res.root)
+        outputs[1][0] = np.bool_(res.converged)
+
+    def L_op(self, inputs, outputs, output_grads):
+        x, *args = inputs
+        x_star, _ = outputs
+        output_grad, _ = output_grads
+
+        inner_x, *inner_args = self.fgraph.inputs
+        inner_fx = self.fgraph.outputs[0]
+
+        grad_wrt_args = scalar_implict_optimization_grads(
+            inner_fx=inner_fx,
+            inner_x=inner_x,
+            inner_args=inner_args,
+            args=args,
+            x_star=x_star,
+            output_grad=output_grad,
+            fgraph=self.fgraph,
+        )
+
+        return [zeros_like(x), *grad_wrt_args]
+
+
+def root_scalar(
+    equation: TensorVariable,
+    variables: TensorVariable,
+    method: str = "secant",
+    jac: bool = False,
+    hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+):
+    """
+    Find the root of a scalar equation using scipy.optimize.root_scalar. Returns the root and a boolean convergence flag.
+    """
+    args = _find_optimization_parameters(equation, variables)
+
+    root_scalar_op = RootScalarOp(
+        variables,
+        *args,
+        equation=equation,
+        method=method,
+        jac=jac,
+        hess=hess,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    return root_scalar_op(variables, *args)
+
+
 class RootOp(ScipyWrapperOp):
-    __props__ = ("method", "jac", "optimizer_kwargs")
+    __props__ = ("method", "jac")

     def __init__(
         self,
@@ -616,4 +720,4 @@ def root(
     return root_op(variables, *args)


-__all__ = ["minimize", "root"]
+__all__ = ["minimize_scalar", "minimize", "root_scalar", "root"]
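
For orientation, here is a minimal usage sketch of the new helper (the pytensor.tensor.optimize import path is assumed here and is not part of this diff). It solves exp(x) - theta = 0 for x and then differentiates the root with respect to theta, which exercises the implicit-function-theorem gradient that L_op delegates to scalar_implict_optimization_grads.

import pytensor.tensor as pt
from pytensor import function, grad
from pytensor.tensor.optimize import root_scalar  # assumed import path, not shown in this diff

x = pt.dscalar("x")          # variable to solve for; its value is used as the initial guess x0
theta = pt.dscalar("theta")  # free parameter of the equation

# Solve exp(x) - theta = 0 for x. jac=True adds f'(x) to the inner graph so Newton's method can use it.
equation = pt.exp(x) - theta
solution, converged = root_scalar(equation, x, method="newton", jac=True)

fn = function([x, theta], [solution, converged])
root, ok = fn(0.0, 2.0)  # root should be close to log(2)

# The root is differentiable with respect to theta via the implicit function theorem,
# d(x*)/d(theta) = -(df/dtheta) / (df/dx) evaluated at x*, which is what L_op computes.
droot_dtheta = grad(solution, theta)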