@@ -746,16 +746,28 @@ def set_key_if_doesnt_exist(d, key, value):
746
746
if rope is not None :
747
747
display_rope (rope )
748
748
749
- def create_axes_grid (figsize , varnames ):
750
- n = np .ceil (len (varnames ) / 2.0 ).astype (int )
749
+ def create_axes_grid (figsize , traces ):
750
+ n = np .ceil (len (traces ) / 2.0 ).astype (int )
751
751
if figsize is None :
752
752
figsize = (12 , n * 2.5 )
753
753
fig , ax = plt .subplots (n , 2 , figsize = figsize )
754
754
ax = ax .reshape (2 * n )
755
- if len (varnames ) % 2 == 1 :
755
+ if len (traces ) % 2 == 1 :
756
756
ax [- 1 ].set_axis_off ()
757
757
ax = ax [:- 1 ]
758
758
return ax , fig
759
+
760
+ def get_trace_dict (tr , varnames ):
761
+ traces = {}
762
+ for v in varnames :
763
+ vals = tr .get_values (v , combine = True , squeeze = True )
764
+ if vals .ndim > 1 :
765
+ vals_flat = vals .reshape (vals .shape [0 ], - 1 ).T
766
+ for i ,vi in enumerate (vals_flat ):
767
+ traces ['_' .join ([v ,str (i )])] = vi
768
+ else :
769
+ traces [v ] = vals
770
+ return traces
759
771
760
772
if isinstance (trace , np .ndarray ):
761
773
if figsize is None :
@@ -770,12 +782,13 @@ def create_axes_grid(figsize, varnames):
770
782
else :
771
783
varnames = [name for name in trace .varnames if not name .endswith ('_' )]
772
784
785
+ trace_dict = get_trace_dict (trace , varnames )
786
+
773
787
if ax is None :
774
- ax , fig = create_axes_grid (figsize , varnames )
788
+ ax , fig = create_axes_grid (figsize , trace_dict )
775
789
776
- for a , v in zip (ax , varnames ):
777
- tr_values = transform (trace .get_values (
778
- v , combine = True , squeeze = True ))
790
+ for a , v in zip (ax , trace_dict ):
791
+ tr_values = transform (trace_dict [v ])
779
792
plot_posterior_op (tr_values , ax = a )
780
793
a .set_title (v )
781
794
0 commit comments