
Commit 19323fe

Implement tempering in ValueGradFunction
1 parent e1782f2 commit 19323fe

File tree

1 file changed, +53 -12 lines changed

pymc3/model.py

Lines changed: 53 additions & 12 deletions
@@ -581,8 +581,9 @@ class ValueGradFunction:
 
     Parameters
     ----------
-    cost: theano variable
-        The value that we compute with its gradient.
+    costs: list of theano variables
+        We compute the weighted sum of the specified theano values, and the gradient
+        of that sum. The weights can be specified with `ValueGradFunction.set_weights`.
     grad_vars: list of named theano variables or None
         The arguments with respect to which the gradient is computed.
     extra_vars: list of named theano variables or None
@@ -610,7 +611,7 @@ class ValueGradFunction:
     """
 
     def __init__(
-        self, cost, grad_vars, extra_vars=None, dtype=None, casting="no", **kwargs
+        self, costs, grad_vars, extra_vars=None, dtype=None, casting="no", **kwargs
     ):
         from .distributions import TensorType
 
@@ -623,19 +624,30 @@ def __init__(
         if len(set(names)) != len(names):
             raise ValueError("Names of the arguments are not unique.")
 
-        if cost.ndim > 0:
-            raise ValueError("Cost must be a scalar.")
-
         self._grad_vars = grad_vars
         self._extra_vars = extra_vars
         self._extra_var_names = {var.name for var in extra_vars}
+
+        if dtype is None:
+            dtype = theano.config.floatX
+        self.dtype = dtype
+
+        self._n_costs = len(costs)
+        if self._n_costs == 0:
+            raise ValueError("At least one cost is required.")
+        weights = np.ones(self._n_costs - 1, dtype=self.dtype)
+        self._weights = theano.shared(weights, "__weights")
+
+        cost = costs[0]
+        for i, val in enumerate(costs[1:]):
+            if cost.ndim > 0 or val.ndim > 0:
+                raise ValueError("All costs must be scalar.")
+            cost = cost + self._weights[i] * val
+
         self._cost = cost
         self._ordering = ArrayOrdering(grad_vars)
         self.size = self._ordering.size
         self._extra_are_set = False
-        if dtype is None:
-            dtype = theano.config.floatX
-        self.dtype = dtype
         for var in self._grad_vars:
             if not np.can_cast(var.dtype, self.dtype, casting):
                 raise TypeError(
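For illustration, here is a minimal standalone sketch of the weighted-sum construction performed by the loop above. The names `logp1` and `logp2` are made up for the example; the default weight of 1 reproduces a plain sum, matching the behaviour when `set_weights` is never called.

import numpy as np
import theano
import theano.tensor as tt

# Two hypothetical scalar costs, e.g. a prior term and a likelihood term.
logp1 = tt.dscalar("logp1")
logp2 = tt.dscalar("logp2")

# One shared weight per cost after the first, initialized to 1 so the
# default behaviour is an unweighted sum, as in __init__ above.
weights = theano.shared(np.ones(1), "__weights")
cost = logp1 + weights[0] * logp2

f = theano.function([logp1, logp2], cost)
print(f(-1.5, -3.0))  # -4.5 with the default weight of 1
weights.set_value(np.array([0.5]))
print(f(-1.5, -3.0))  # -3.0 after lowering the weight to 0.5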
@@ -674,6 +686,11 @@ def __init__(
             inputs, [self._cost_joined, grad], givens=givens, **kwargs
         )
 
+    def set_weights(self, values):
+        if values.shape != (self._n_costs - 1,):
+            raise ValueError("Invalid shape. Must be (n_costs - 1,).")
+        self._weights.set_value(values)
+
     def set_extra_values(self, extra_vars):
         self._extra_are_set = True
         for var in self._extra_vars:
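The shape check in `set_weights` expects exactly one weight per cost after the first. A small self-contained sketch of that check; the helper `_check_weights` is illustrative, not part of the commit.

import numpy as np

def _check_weights(values, n_costs):
    # Mirrors the validation in set_weights above: the weights array must
    # have one entry per cost beyond the first.
    if values.shape != (n_costs - 1,):
        raise ValueError("Invalid shape. Must be (n_costs - 1,).")

_check_weights(np.array([0.5]), n_costs=2)        # accepted: shape (1,)
# _check_weights(np.array([0.5, 0.1]), n_costs=2) # would raise ValueError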
@@ -940,7 +957,18 @@ def dlogp_array(self):
         vars = inputvars(self.cont_vars)
         return self.bijection.mapf(self.fastdlogp(vars))
 
-    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+    def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
+        """Compile a theano function that computes logp and gradient.
+
+        Parameters
+        ----------
+        grad_vars: list of random variables, optional
+            Compute the gradient with respect to those variables. If None,
+            use all free random variables of this model.
+        tempered: bool
+            Compute the tempered logp `free_logp + alpha * observed_logp`.
+            `alpha` can be changed using `ValueGradFunction.set_weights([alpha])`.
+        """
         if grad_vars is None:
             grad_vars = list(typefilter(self.free_RVs, continuous_types))
         else:
@@ -949,9 +977,22 @@ def logp_dlogp_function(self, grad_vars=None, **kwargs):
                     raise ValueError(
                         "Can only compute the gradient of " "continuous types: %s" % var
                     )
+
+        if tempered:
+            with self:
+                free_RVs_logp = tt.sum([
+                    tt.sum(var.logpt) for var in self.free_RVs + self.potentials
+                ])
+                observed_RVs_logp = tt.sum([
+                    tt.sum(var.logpt) for var in self.observed_RVs
+                ])
+
+            costs = [free_RVs_logp, observed_RVs_logp]
+        else:
+            costs = [self.logpt]
         varnames = [var.name for var in grad_vars]
         extra_vars = [var for var in self.free_RVs if var.name not in varnames]
-        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+        return ValueGradFunction(costs, grad_vars, extra_vars, **kwargs)
 
     @property
     def logpt(self):
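Putting the pieces together, a hedged end-to-end sketch of how the tempered path could be used. The toy model, the weight value 0.5, and the call convention (a raveled parameter array of length `func.size`) are illustrative assumptions, not part of the commit. With `tempered=True` the compiled function evaluates `free_logp + alpha * observed_logp`, where `alpha` defaults to 1.

import numpy as np
import pymc3 as pm

# Hypothetical toy model for the example.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(50))

# costs = [free_logp, observed_logp] internally, so alpha weights the likelihood.
func = model.logp_dlogp_function(tempered=True)
func.set_extra_values({})  # no extra (non-gradient) variables in this model

# alpha = 0 keeps only the prior and potentials; alpha = 1 gives the full logp.
func.set_weights(np.array([0.5], dtype=func.dtype))
logp, dlogp = func(np.zeros(func.size, dtype=func.dtype))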
@@ -1050,7 +1091,7 @@ def add_coords(self, coords):
             return
 
         for name in coords:
-            if name in { "draw", "chain" }:
+            if name in {"draw", "chain"}:
                 raise ValueError(
                     "Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs."
                 )
