Skip to content

Commit 05c40b7

Browse files
Factor out parameter pre-processing in TestMatchesScipy
1 parent bf6cce0 commit 05c40b7

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

pymc3/tests/test_distributions.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -647,28 +647,8 @@ def logp_reference(args):
647647
domains = paramdomains.copy()
648648
domains["value"] = domain
649649
for pt in product(domains, n_samples=n_samples):
650-
651650
pt = dict(pt)
652-
pt_d = {}
653-
for k, v in pt.items():
654-
rv_var = model.named_vars.get(k)
655-
nv = param_vars.get(k, rv_var)
656-
nv = getattr(nv.tag, "value_var", nv)
657-
658-
transform = getattr(nv.tag, "transform", None)
659-
if transform:
660-
# TODO: The compiled graph behind this should be cached and
661-
# reused (if it isn't already).
662-
v = transform.forward(rv_var, v).eval()
663-
664-
if nv.name in param_vars:
665-
# Update the shared parameter variables in `param_vars`
666-
param_vars[nv.name].set_value(v)
667-
else:
668-
# Create an argument entry for the (potentially
669-
# transformed) "value" variable
670-
pt_d[nv.name] = v
671-
651+
pt_d = self._model_input_dict(model, param_vars, pt)
672652
pt_logp = Point(pt_d, model=model)
673653
pt_ref = Point(pt, filter_model_vars=False, model=model)
674654
assert_almost_equal(
@@ -678,6 +658,30 @@ def logp_reference(args):
678658
err_msg=str(pt),
679659
)
680660

661+
def _model_input_dict(self, model, param_vars, pt):
662+
"""Create a dict with only the necessary, transformed logp inputs."""
663+
pt_d = {}
664+
for k, v in pt.items():
665+
rv_var = model.named_vars.get(k)
666+
nv = param_vars.get(k, rv_var)
667+
nv = getattr(nv.tag, "value_var", nv)
668+
669+
transform = getattr(nv.tag, "transform", None)
670+
if transform:
671+
# todo: the compiled graph behind this should be cached and
672+
# reused (if it isn't already).
673+
v = transform.forward(rv_var, v).eval()
674+
675+
if nv.name in param_vars:
676+
# update the shared parameter variables in `param_vars`
677+
param_vars[nv.name].set_value(v)
678+
else:
679+
# create an argument entry for the (potentially
680+
# transformed) "value" variable
681+
pt_d[nv.name] = v
682+
683+
return pt_d
684+
681685
def check_logcdf(
682686
self,
683687
pymc3_dist,

0 commit comments

Comments
 (0)