Skip to content

Commit 4a8f64f

Browse files
author
Junpeng Lao
committed
scaling for NUTS using 'advi_map'
remaining fix in #2110
1 parent eb6042f commit 4a8f64f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc3/sampling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,8 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
602602
obj_optimizer=pm.adagrad_window
603603
)
604604
start = approx.sample(draws=njobs)
605-
cov = approx.cov.eval()
605+
stds = approx.gbij.rmap(approx.std.eval())
606+
cov = model.dict_to_array(stds) ** 2
606607
if njobs == 1:
607608
start = start[0]
608609
elif init == 'map':

0 commit comments

Comments
 (0)