@@ -647,28 +647,8 @@ def logp_reference(args):
647
647
domains = paramdomains .copy ()
648
648
domains ["value" ] = domain
649
649
for pt in product (domains , n_samples = n_samples ):
650
-
651
650
pt = dict (pt )
652
- pt_d = {}
653
- for k , v in pt .items ():
654
- rv_var = model .named_vars .get (k )
655
- nv = param_vars .get (k , rv_var )
656
- nv = getattr (nv .tag , "value_var" , nv )
657
-
658
- transform = getattr (nv .tag , "transform" , None )
659
- if transform :
660
- # TODO: The compiled graph behind this should be cached and
661
- # reused (if it isn't already).
662
- v = transform .forward (rv_var , v ).eval ()
663
-
664
- if nv .name in param_vars :
665
- # Update the shared parameter variables in `param_vars`
666
- param_vars [nv .name ].set_value (v )
667
- else :
668
- # Create an argument entry for the (potentially
669
- # transformed) "value" variable
670
- pt_d [nv .name ] = v
671
-
651
+ pt_d = self ._model_input_dict (model , param_vars , pt )
672
652
pt_logp = Point (pt_d , model = model )
673
653
pt_ref = Point (pt , filter_model_vars = False , model = model )
674
654
assert_almost_equal (
@@ -678,6 +658,30 @@ def logp_reference(args):
678
658
err_msg = str (pt ),
679
659
)
680
660
661
+ def _model_input_dict (self , model , param_vars , pt ):
662
+ """Create a dict with only the necessary, transformed logp inputs."""
663
+ pt_d = {}
664
+ for k , v in pt .items ():
665
+ rv_var = model .named_vars .get (k )
666
+ nv = param_vars .get (k , rv_var )
667
+ nv = getattr (nv .tag , "value_var" , nv )
668
+
669
+ transform = getattr (nv .tag , "transform" , None )
670
+ if transform :
671
+ # todo: the compiled graph behind this should be cached and
672
+ # reused (if it isn't already).
673
+ v = transform .forward (rv_var , v ).eval ()
674
+
675
+ if nv .name in param_vars :
676
+ # update the shared parameter variables in `param_vars`
677
+ param_vars [nv .name ].set_value (v )
678
+ else :
679
+ # create an argument entry for the (potentially
680
+ # transformed) "value" variable
681
+ pt_d [nv .name ] = v
682
+
683
+ return pt_d
684
+
681
685
def check_logcdf (
682
686
self ,
683
687
pymc3_dist ,
0 commit comments