Skip to content

Commit c4ccbee

Browse files
brandonwillard authored and twiecki committed
Allow dict of value vars in logp signatures and implement Subtensor logp
1 parent e1fedb8 commit c4ccbee

File tree

3 files changed

+150
-40
lines changed

3 files changed

+150
-40
lines changed

pymc3/distributions/distribution.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,17 @@ def _random(*args, **kwargs):
9696
if class_logp:
9797

9898
@_logp.register(rv_type)
99-
def logp(op, value, *dist_params, **kwargs):
100-
return class_logp(value, *dist_params, **kwargs)
99+
def logp(op, var, rvs_to_values, *dist_params, **kwargs):
100+
value_var = rvs_to_values.get(var, var)
101+
return class_logp(value_var, *dist_params, **kwargs)
101102

102103
class_logcdf = clsdict.get("logcdf")
103104
if class_logcdf:
104105

105106
@_logcdf.register(rv_type)
106-
def logcdf(op, value, *dist_params, **kwargs):
107-
return class_logcdf(value, *dist_params, **kwargs)
107+
def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
108+
value_var = rvs_to_values.get(var, var)
109+
return class_logcdf(value_var, *dist_params, **kwargs)
108110

109111
# class_transform = clsdict.get("transform")
110112
# if class_transform:

pymc3/distributions/logp.py

Lines changed: 84 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections.abc import Mapping
1516
from functools import singledispatch
16-
from typing import Optional
17+
from typing import Dict, Optional, Union
1718

1819
import aesara.tensor as at
1920
import numpy as np
2021

2122
from aesara import config
2223
from aesara.gradient import disconnected_grad
2324
from aesara.graph.basic import Constant, clone, graph_inputs, io_toposort
25+
from aesara.graph.fg import FunctionGraph
2426
from aesara.graph.op import Op, compute_test_value
2527
from aesara.graph.type import CType
28+
from aesara.tensor.random.op import RandomVariable
29+
from aesara.tensor.random.opt import local_subtensor_rv_lift
2630
from aesara.tensor.subtensor import (
2731
AdvancedIncSubtensor,
2832
AdvancedIncSubtensor1,
@@ -107,7 +111,7 @@ def _get_scaling(total_size, shape, ndim):
107111

108112
def logpt(
109113
var: TensorVariable,
110-
rv_value: Optional[TensorVariable] = None,
114+
rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
111115
*,
112116
jacobian: bool = True,
113117
scaling: bool = True,
@@ -127,10 +131,10 @@ def logpt(
127131
==========
128132
var
129133
The `RandomVariable` output that determines the log-likelihood graph.
130-
rv_value
131-
The variable that represents the value of `var` in its log-likelihood.
132-
If no `rv_value` is provided, ``var.tag.value_var`` will be checked
133-
and, when available, used.
134+
rv_values
135+
A variable, or ``dict`` of variables, that represents the value of
136+
`var` in its log-likelihood. If no `rv_values` is provided,
137+
``var.tag.value_var`` will be checked and, when available, used.
134138
jacobian
135139
Whether or not to include the Jacobian term.
136140
scaling
@@ -143,16 +147,17 @@ def logpt(
143147
Sum the log-likelihood.
144148
145149
"""
150+
if not isinstance(rv_values, Mapping):
151+
rv_values = {var: rv_values} if rv_values is not None else {}
146152

147153
rv_var, rv_value_var = extract_rv_and_value_vars(var)
148154

149-
if rv_value is None:
155+
rv_value = rv_values.get(rv_var, rv_value_var)
150156

151-
if rv_var is not None and rv_value_var is None:
152-
raise ValueError(f"No value variable specified or associated with {rv_var}")
157+
if rv_var is not None and rv_value is None:
158+
raise ValueError(f"No value variable specified or associated with {rv_var}")
153159

154-
rv_value = rv_value_var
155-
else:
160+
if rv_value is not None:
156161
rv_value = at.as_tensor(rv_value)
157162

158163
if rv_var is not None:
@@ -163,12 +168,12 @@ def logpt(
163168
rv_value_var = rv_value
164169

165170
if rv_var is None:
166-
167171
if var.owner is not None:
168172
return _logp(
169173
var.owner.op,
170-
rv_value,
171-
var.owner.inputs,
174+
var,
175+
rv_values,
176+
*var.owner.inputs,
172177
jacobian=jacobian,
173178
scaling=scaling,
174179
transformed=transformed,
@@ -189,10 +194,13 @@ def logpt(
189194
# Ultimately, with a graph containing only random variables and
190195
# "deterministics", we can simply replace all the random variables with
191196
# their value variables and be done.
197+
tmp_rv_values = rv_values.copy()
198+
tmp_rv_values[rv_var] = rv_var
199+
192200
if not cdf:
193-
logp_var = _logp(rv_node.op, rv_var, *dist_params, **kwargs)
201+
logp_var = _logp(rv_node.op, rv_var, tmp_rv_values, *dist_params, **kwargs)
194202
else:
195-
logp_var = _logcdf(rv_node.op, rv_var, *dist_params, **kwargs)
203+
logp_var = _logcdf(rv_node.op, rv_var, tmp_rv_values, *dist_params, **kwargs)
196204

197205
transform = getattr(rv_value_var.tag, "transform", None) if rv_value_var else None
198206

@@ -204,10 +212,13 @@ def logpt(
204212
logp_var += transformed_jacobian
205213

206214
# Replace random variables with their value variables
215+
replacements = rv_values.copy()
216+
replacements.update({rv_var: rv_value, rv_value_var: rv_value})
217+
207218
(logp_var,), _ = rvs_to_value_vars(
208219
(logp_var,),
209220
apply_transforms=transformed and not cdf,
210-
initial_replacements={rv_var: rv_value, rv_value_var: rv_value},
221+
initial_replacements=replacements,
211222
)
212223

213224
if sum:
@@ -231,15 +242,24 @@ def logpt(
231242

232243

233244
@singledispatch
234-
def _logp(op: Op, value: TensorVariable, *dist_params, **kwargs):
245+
def _logp(
246+
op: Op,
247+
var: TensorVariable,
248+
rvs_to_values: Dict[TensorVariable, TensorVariable],
249+
*inputs: TensorVariable,
250+
**kwargs,
251+
):
235252
"""Create a log-likelihood graph.
236253
237254
This function dispatches on the type of `op`, which should be a subclass
238255
of `RandomVariable`. If you want to implement new log-likelihood graphs
239256
for a `RandomVariable`, register a new function on this dispatcher.
240257
258+
The default assumes that the log-likelihood of a term is zero.
259+
241260
"""
242-
return at.zeros_like(value)
261+
value_var = rvs_to_values.get(var, var)
262+
return at.zeros_like(value_var)
243263

244264

245265
def convert_indices(indices, entry):
@@ -256,39 +276,70 @@ def convert_indices(indices, entry):
256276
return entry
257277

258278

259-
def index_from_subtensor(idx_list, indices):
279+
def indices_from_subtensor(idx_list, indices):
260280
"""Compute a useable index tuple from the inputs of a ``*Subtensor**`` ``Op``."""
261-
index = tuple(tuple(convert_indices(indices, idx) for idx in idx_list) if idx_list else indices)
262-
if len(index) == 1:
263-
index = index[0]
264-
return index
281+
return tuple(
282+
tuple(convert_indices(list(indices), idx) for idx in idx_list) if idx_list else indices
283+
)
265284

266285

267286
@_logp.register(IncSubtensor)
268287
@_logp.register(AdvancedIncSubtensor)
269288
@_logp.register(AdvancedIncSubtensor1)
270-
def incsubtensor_logp(op, value, inputs, **kwargs):
271-
rv_var, rv_values, *indices = inputs
289+
def incsubtensor_logp(op, var, rvs_to_values, indexed_rv_var, rv_values, *indices, **kwargs):
272290

273-
index = index_from_subtensor(getattr(op, "idx_list", None), indices)
291+
index = indices_from_subtensor(getattr(op, "idx_list", None), indices)
274292

275293
_, (new_rv_var,) = clone(
276-
tuple(v for v in graph_inputs((rv_var,)) if not isinstance(v, Constant)),
277-
(rv_var,),
294+
tuple(v for v in graph_inputs((indexed_rv_var,)) if not isinstance(v, Constant)),
295+
(indexed_rv_var,),
278296
copy_inputs=False,
279297
copy_orphans=False,
280298
)
281299
new_values = at.set_subtensor(disconnected_grad(new_rv_var)[index], rv_values)
282-
logp_var = logpt(rv_var, new_values, **kwargs)
300+
logp_var = logpt(indexed_rv_var, new_values, **kwargs)
283301

284302
return logp_var
285303

286304

287305
@_logp.register(Subtensor)
288306
@_logp.register(AdvancedSubtensor)
289307
@_logp.register(AdvancedSubtensor1)
290-
def subtensor_logp(op, value, *inputs, **kwargs):
291-
raise NotImplementedError()
308+
def subtensor_logp(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
309+
310+
index = indices_from_subtensor(getattr(op, "idx_list", None), indices)
311+
312+
rv_value = rvs_to_values.get(var, getattr(var.tag, "value_var", None))
313+
314+
if indexed_rv_var.owner and isinstance(indexed_rv_var.owner.op, RandomVariable):
315+
316+
# We need to lift the index operation through the random variable so
317+
# that we have a new random variable consisting of only the relevant
318+
# subset of variables per the index.
319+
var_copy = var.owner.clone().default_output()
320+
fgraph = FunctionGraph(
321+
[i for i in graph_inputs((indexed_rv_var,)) if not isinstance(i, Constant)],
322+
[var_copy],
323+
clone=False,
324+
)
325+
326+
(lifted_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
327+
328+
new_rvs_to_values = rvs_to_values.copy()
329+
new_rvs_to_values[lifted_var] = rv_value
330+
331+
logp_var = logpt(lifted_var, new_rvs_to_values, **kwargs)
332+
333+
for idx_var in index:
334+
logp_var += logpt(idx_var, rvs_to_values, **kwargs)
335+
336+
# TODO: We could add the constant case (i.e. `indexed_rv_var.owner is None`)
337+
else:
338+
raise NotImplementedError(
339+
f"`Subtensor` log-likelihood not implemented for {indexed_rv_var.owner}"
340+
)
341+
342+
return logp_var
292343

293344

294345
def logcdf(*args, **kwargs):
@@ -297,7 +348,7 @@ def logcdf(*args, **kwargs):
297348

298349

299350
@singledispatch
300-
def _logcdf(op, value, *args, **kwargs):
351+
def _logcdf(op, values, *args, **kwargs):
301352
"""Create a log-CDF graph.
302353
303354
This function dispatches on the type of `op`, which should be a subclass

pymc3/tests/test_logp.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,18 @@
2424
from aesara.tensor.subtensor import (
2525
AdvancedIncSubtensor,
2626
AdvancedIncSubtensor1,
27+
AdvancedSubtensor,
28+
AdvancedSubtensor1,
2729
IncSubtensor,
30+
Subtensor,
2831
)
2932

3033
from pymc3.aesaraf import floatX, walk_model
3134
from pymc3.distributions.continuous import Normal, Uniform
35+
from pymc3.distributions.discrete import Bernoulli
3236
from pymc3.distributions.logp import logpt
3337
from pymc3.model import Model
38+
from pymc3.tests.helpers import select_by_precision
3439

3540

3641
def test_logpt_basic():
@@ -73,16 +78,16 @@ def test_logpt_basic():
7378
((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)),
7479
],
7580
)
76-
def test_logpt_univariate_incsubtensor(indices, size):
81+
def test_logpt_incsubtensor(indices, size):
7782
"""Make sure we can compute a log-likelihood for ``Y[idx] = data`` where ``Y`` is univariate."""
7883

7984
mu = floatX(np.power(10, np.arange(np.prod(size)))).reshape(size)
8085
data = mu[indices]
8186
sigma = 0.001
8287
rng = aesara.shared(np.random.RandomState(232), borrow=True)
8388

84-
with Model() as m:
85-
a = Normal("a", mu, sigma, size=size, rng=rng)
89+
a = Normal.dist(mu, sigma, size=size, rng=rng)
90+
a.name = "a"
8691

8792
a_idx = at.set_subtensor(a[indices], data)
8893

@@ -131,3 +136,55 @@ def test_logpt_univariate_incsubtensor(indices, size):
131136
assert isinstance(a_client.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))
132137
indices = tuple(i.eval() for i in a_client.inputs[2:])
133138
np.testing.assert_almost_equal(indices, indices)
139+
140+
141+
def test_logpt_subtensor():
142+
"""Make sure we can compute a log-likelihood for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
143+
144+
size = 5
145+
146+
mu_base = floatX(np.power(10, np.arange(np.prod(size)))).reshape(size)
147+
mu = np.stack([mu_base, -mu_base])
148+
sigma = 0.001
149+
rng = aesara.shared(np.random.RandomState(232), borrow=True)
150+
151+
A_rv = Normal.dist(mu, sigma, rng=rng)
152+
A_rv.name = "A"
153+
154+
p = 0.5
155+
156+
I_rv = Bernoulli.dist(p, size=size, rng=rng)
157+
I_rv.name = "I"
158+
159+
A_idx = A_rv[I_rv, at.ogrid[A_rv.shape[-1] :]]
160+
161+
assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
162+
163+
A_idx_value_var = A_idx.type()
164+
A_idx_value_var.name = "A_idx_value"
165+
166+
I_value_var = I_rv.type()
167+
I_value_var.name = "I_value"
168+
169+
A_idx_logp = logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var})
170+
171+
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp)
172+
173+
# The compiled graph should not contain any `RandomVariables`
174+
assert not any(isinstance(n.op, RandomVariable) for n in logp_vals_fn.maker.fgraph.apply_nodes)
175+
176+
decimals = select_by_precision(float64=6, float32=4)
177+
178+
for i in range(10):
179+
bern_sp = sp.bernoulli(p)
180+
I_value = bern_sp.rvs(size=size).astype(I_rv.dtype)
181+
182+
norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
183+
A_idx_value = norm_sp.rvs().astype(A_idx.dtype)
184+
185+
exp_obs_logps = norm_sp.logpdf(A_idx_value)
186+
exp_obs_logps += bern_sp.logpmf(I_value)
187+
188+
logp_vals = logp_vals_fn(A_idx_value, I_value)
189+
190+
np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)

0 commit comments

Comments
 (0)