@@ -221,13 +221,16 @@ def product(domains, n_samples=-1):
221
221
def build_model (distfam , valuedomain , vardomains , extra_args = None ):
222
222
if extra_args is None :
223
223
extra_args = {}
224
+
224
225
with Model () as m :
225
- vals = {}
226
+ param_vars = {}
226
227
for v , dom in vardomains .items ():
227
- vals [v ] = dom .vals [0 ]
228
- vals .update (extra_args )
229
- distfam ("value" , size = valuedomain .shape , transform = None , ** vals )
230
- return m
228
+ v_at = aesara .shared (np .asarray (dom .vals [0 ]))
229
+ v_at .name = v
230
+ param_vars [v ] = v_at
231
+ param_vars .update (extra_args )
232
+ distfam ("value" , ** param_vars , transform = None )
233
+ return m , param_vars
231
234
232
235
233
236
def laplace_asymmetric_logpdf (value , kappa , b , mu ):
@@ -606,14 +609,34 @@ def logp_reference(args):
606
609
args .update (scipy_args )
607
610
return scipy_logp (** args )
608
611
609
- model = build_model (pymc3_dist , domain , paramdomains , extra_args )
610
- logp = model .fastlogp
612
+ model , param_vars = build_model (pymc3_dist , domain , paramdomains , extra_args )
613
+ logp = model .fastlogp_nojac
611
614
612
615
domains = paramdomains .copy ()
613
616
domains ["value" ] = domain
614
617
for pt in product (domains , n_samples = n_samples ):
618
+
615
619
pt = dict (pt )
616
- pt_logp = Point (pt , model = model )
620
+ pt_d = {}
621
+ for k , v in pt .items ():
622
+ nv = param_vars .get (k , model .named_vars .get (k ))
623
+ nv = getattr (nv .tag , "value_var" , nv )
624
+
625
+ transform = getattr (nv .tag , "transform" , None )
626
+ if transform :
627
+ # TODO: The compiled graph behind this should be cached and
628
+ # reused (if it isn't already).
629
+ v = transform .forward (v ).eval ()
630
+
631
+ if nv .name in param_vars :
632
+ # Update the shared parameter variables in `param_vars`
633
+ param_vars [nv .name ].set_value (v )
634
+ else :
635
+ # Create an argument entry for the (potentially
636
+ # transformed) "value" variable
637
+ pt_d [nv .name ] = v
638
+
639
+ pt_logp = Point (pt_d , model = model )
617
640
pt_ref = Point (pt , filter_model_vars = False , model = model )
618
641
assert_almost_equal (
619
642
logp (pt_logp ),
0 commit comments