Skip to content

Commit 735278a

Browse files
Make logp testing work with transformed values
1 parent 573071f commit 735278a

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

pymc3/tests/test_distributions.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,16 @@ def product(domains, n_samples=-1):
221221
def build_model(distfam, valuedomain, vardomains, extra_args=None):
222222
if extra_args is None:
223223
extra_args = {}
224+
224225
with Model() as m:
225-
vals = {}
226+
param_vars = {}
226227
for v, dom in vardomains.items():
227-
vals[v] = dom.vals[0]
228-
vals.update(extra_args)
229-
distfam("value", size=valuedomain.shape, transform=None, **vals)
230-
return m
228+
v_at = aesara.shared(np.asarray(dom.vals[0]))
229+
v_at.name = v
230+
param_vars[v] = v_at
231+
param_vars.update(extra_args)
232+
distfam("value", **param_vars, transform=None)
233+
return m, param_vars
231234

232235

233236
def laplace_asymmetric_logpdf(value, kappa, b, mu):
@@ -606,14 +609,34 @@ def logp_reference(args):
606609
args.update(scipy_args)
607610
return scipy_logp(**args)
608611

609-
model = build_model(pymc3_dist, domain, paramdomains, extra_args)
610-
logp = model.fastlogp
612+
model, param_vars = build_model(pymc3_dist, domain, paramdomains, extra_args)
613+
logp = model.fastlogp_nojac
611614

612615
domains = paramdomains.copy()
613616
domains["value"] = domain
614617
for pt in product(domains, n_samples=n_samples):
618+
615619
pt = dict(pt)
616-
pt_logp = Point(pt, model=model)
620+
pt_d = {}
621+
for k, v in pt.items():
622+
nv = param_vars.get(k, model.named_vars.get(k))
623+
nv = getattr(nv.tag, "value_var", nv)
624+
625+
transform = getattr(nv.tag, "transform", None)
626+
if transform:
627+
# TODO: The compiled graph behind this should be cached and
628+
# reused (if it isn't already).
629+
v = transform.forward(v).eval()
630+
631+
if nv.name in param_vars:
632+
# Update the shared parameter variables in `param_vars`
633+
param_vars[nv.name].set_value(v)
634+
else:
635+
# Create an argument entry for the (potentially
636+
# transformed) "value" variable
637+
pt_d[nv.name] = v
638+
639+
pt_logp = Point(pt_d, model=model)
617640
pt_ref = Point(pt, filter_model_vars=False, model=model)
618641
assert_almost_equal(
619642
logp(pt_logp),

0 commit comments

Comments
 (0)