Commit a276aac

Add compute_grad option to model.logp_dlogp_function
1 parent eb16420 commit a276aac
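
In short, this commit lets callers ask model.logp_dlogp_function for a log-probability function that skips the gradient. Below is a minimal usage sketch of the new keyword, as exercised by the test change further down; the toy model and the names point and logp_only are illustrative assumptions, not part of this commit:

import numpy as np
import pymc3 as pm

with pm.Model() as model:
    pm.Normal("x", 0.0, 1.0)

# Default behaviour: the compiled function returns both logp and its gradient.
func = model.logp_dlogp_function()
func.set_extra_values({})
point = np.ones(func.size, dtype=func.dtype)
logp, dlogp = func(point)

# With compute_grads=False, tt.grad is never added to the Theano graph and the
# call returns only the log-probability.
func_nograd = model.logp_dlogp_function(compute_grads=False)
func_nograd.set_extra_values({})
logp_only = func_nograd(point)

Leaving the gradient out of the compiled graph keeps it smaller, which can make compilation and evaluation cheaper when only logp values are needed.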

2 files changed: 39 additions & 12 deletions

pymc3/model.py

Lines changed: 26 additions & 12 deletions
@@ -597,6 +597,8 @@ class ValueGradFunction:
         See `numpy.can_cast` for a description of the options.
         Keep in mind that we cast the variables to the array *and*
         back from the array dtype to the variable dtype.
+    compute_grads: bool, default=True
+        If False, return only the logp, not the gradient.
     kwargs
         Extra arguments are passed on to `theano.function`.

@@ -611,7 +613,15 @@ class ValueGradFunction:
     """

     def __init__(
-        self, costs, grad_vars, extra_vars=None, dtype=None, casting="no", **kwargs
+        self,
+        costs,
+        grad_vars,
+        extra_vars=None,
+        *,
+        dtype=None,
+        casting="no",
+        compute_grads=True,
+        **kwargs
     ):
         from .distributions import TensorType

@@ -651,13 +661,13 @@ def __init__(
         for var in self._grad_vars:
             if not np.can_cast(var.dtype, self.dtype, casting):
                 raise TypeError(
-                    "Invalid dtype for variable %s. Can not "
-                    "cast to %s with casting rule %s." % (var.name, self.dtype, casting)
+                    f"Invalid dtype for variable {var.name}. Can not "
+                    f"cast to {self.dtype} with casting rule {casting}."
                 )
             if not np.issubdtype(var.dtype, np.floating):
                 raise TypeError(
-                    "Invalid dtype for variable %s. Must be "
-                    "floating point but is %s." % (var.name, var.dtype)
+                    f"Invalid dtype for variable {var.name}. Must be "
+                    f"floating point but is {var.dtype}."
                 )

         givens = []
@@ -677,13 +687,17 @@ def __init__(
             self._cost, grad_vars, self._ordering.vmap
         )

-        grad = tt.grad(self._cost_joined, self._vars_joined)
-        grad.name = "__grad"
+        if compute_grads:
+            grad = tt.grad(self._cost_joined, self._vars_joined)
+            grad.name = "__grad"
+            outputs = [self._cost_joined, grad]
+        else:
+            outputs = self._cost_joined

         inputs = [self._vars_joined]

         self._theano_function = theano.function(
-            inputs, [self._cost_joined, grad], givens=givens, **kwargs
+            inputs, outputs, givens=givens, **kwargs
         )

     def set_weights(self, values):
@@ -723,12 +737,12 @@ def __call__(self, array, grad_out=None, extra_vars=None):
         else:
             out = grad_out

-        logp, dlogp = self._theano_function(array)
+        output = self._theano_function(array)
         if grad_out is None:
-            return logp, dlogp
+            return output
         else:
-            np.copyto(out, dlogp)
-            return logp
+            np.copyto(out, output[1])
+            return output[0]

     @property
     def profile(self):

pymc3/tests/test_model.py

Lines changed: 13 additions & 0 deletions
@@ -393,8 +393,18 @@ def test_tempered_logp_dlogp():
     func_temp = model.logp_dlogp_function(tempered=True)
     func_temp.set_extra_values({})

+    func_nograd = model.logp_dlogp_function(compute_grads=False)
+    func_nograd.set_extra_values({})
+
+    func_temp_nograd = model.logp_dlogp_function(
+        tempered=True, compute_grads=False
+    )
+    func_temp_nograd.set_extra_values({})
+
     x = np.ones(func.size, dtype=func.dtype)
     assert func(x) == func_temp(x)
+    assert func_nograd(x) == func(x)[0]
+    assert func_temp_nograd(x) == func(x)[0]

     func_temp.set_weights(np.array([0.], dtype=func.dtype))
     func_temp_nograd.set_weights(np.array([0.], dtype=func.dtype))
@@ -408,3 +418,6 @@ def test_tempered_logp_dlogp():
     func_temp_nograd.set_weights(np.array([0.5], dtype=func.dtype))
     npt.assert_allclose(func(x)[0], 4 / 3 * func_temp(x)[0])
     npt.assert_allclose(func(x)[1], func_temp(x)[1])
+
+    npt.assert_allclose(func_nograd(x), func(x)[0])
+    npt.assert_allclose(func_temp_nograd(x), func_temp(x)[0])

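A side note on the __call__ changes in pymc3/model.py above: the grad_out path copies the second output, so passing a preallocated gradient buffer only makes sense for functions built with gradients (the default). A small sketch, reusing the hypothetical func and point from the example near the top:

# __call__ writes the gradient into the buffer in place and returns just the logp;
# only valid when the function was created with compute_grads=True.
grad_buffer = np.empty(func.size, dtype=func.dtype)
logp = func(point, grad_out=grad_buffer)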