Skip to content

Commit 2511a58

Browse files
aloctavodiatwiecki
authored andcommitted
refactor kde-related functions and small fixes (#2191)
* refactor kde-related functions and small fixes * autopep8
1 parent 6c31539 commit 2511a58

File tree

6 files changed

+96
-117
lines changed

6 files changed

+96
-117
lines changed

pymc3/plots/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .autocorrplot import autocorrplot
22
from .compareplot import compareplot
33
from .forestplot import forestplot
4-
from .kdeplot import kdeplot, kde2plot
4+
from .kdeplot import kdeplot
55
from .posteriorplot import plot_posterior, plot_posterior_predictive_glm
66
from .traceplot import traceplot
77
from .energyplot import energyplot

pymc3/plots/artists.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
2-
from scipy.stats import kde, mode
2+
from scipy.stats import mode
33

44
from pymc3.stats import hpd
5-
from .utils import fast_kde
5+
from .kdeplot import fast_kde, kdeplot
66

77

88
def _histplot_bins(column, bins=100):
@@ -16,7 +16,8 @@ def histplot_op(ax, data, alpha=.35):
1616
"""Add a histogram for each column of the data to the provided axes."""
1717
hs = []
1818
for column in data.T:
19-
hs.append(ax.hist(column, bins=_histplot_bins(column), alpha=alpha, align='left'))
19+
hs.append(ax.hist(column, bins=_histplot_bins(
20+
column), alpha=alpha, align='left'))
2021
ax.set_xlim(np.min(data) - 0.5, np.max(data) + 0.5)
2122
return hs
2223

@@ -32,7 +33,8 @@ def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'):
3233
x = np.linspace(l, u, len(density))
3334
if prior is not None:
3435
p = prior.logp(x).eval()
35-
pls.append(ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style))
36+
pls.append(ax.plot(x, np.exp(p),
37+
alpha=prior_alpha, ls=prior_style))
3638

3739
ls.append(ax.plot(x, density))
3840
except ValueError:
@@ -46,26 +48,7 @@ def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'):
4648
return ls, pls
4749

4850

49-
def kde2plot_op(ax, x, y, grid=200, **kwargs):
50-
xmin = x.min()
51-
xmax = x.max()
52-
ymin = y.min()
53-
ymax = y.max()
54-
extent = kwargs.pop('extent', [])
55-
if len(extent) != 4:
56-
extent = [xmin, xmax, ymin, ymax]
57-
58-
grid = grid * 1j
59-
X, Y = np.mgrid[xmin:xmax:grid, ymin:ymax:grid]
60-
positions = np.vstack([X.ravel(), Y.ravel()])
61-
values = np.vstack([x, y])
62-
kernel = kde.gaussian_kde(values)
63-
Z = np.reshape(kernel(positions).T, X.shape)
64-
65-
ax.imshow(np.rot90(Z), extent=extent, **kwargs)
66-
67-
68-
def plot_posterior_op(trace_values, figsize, ax, kde_plot, point_estimate, round_to,
51+
def plot_posterior_op(trace_values, ax, kde_plot, point_estimate, round_to,
6952
alpha_level, ref_val, rope, text_size=16, **kwargs):
7053
"""Artist to draw posterior."""
7154
def format_as_percent(x, round_to=0):
@@ -139,9 +122,8 @@ def set_key_if_doesnt_exist(d, key, value):
139122
d[key] = value
140123

141124
if kde_plot:
142-
density, l, u = fast_kde(trace_values)
143-
x = np.linspace(l, u, len(density))
144-
ax.plot(x, density, figsize=figsize, **kwargs)
125+
kdeplot(trace_values, alpha=0.35, ax=ax, **kwargs)
126+
145127
else:
146128
set_key_if_doesnt_exist(kwargs, 'bins', 30)
147129
set_key_if_doesnt_exist(kwargs, 'edgecolor', 'w')

pymc3/plots/energyplot.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
33

4-
from .utils import fast_kde
4+
from .kdeplot import kdeplot
55

6-
def energyplot(trace, kind='kde', figsize=None, ax=None, legend=True, lw=0, alpha=0.5, frame=True, **kwargs):
6+
7+
def energyplot(trace, kind='kde', figsize=None, ax=None, legend=True, lw=0,
8+
alpha=0.35, frame=True, **kwargs):
79
"""Plot energy transition distribution and marginal energy distribution in order
810
to diagnose poor exploration by HMC algorithms.
911
@@ -25,49 +27,47 @@ def energyplot(trace, kind='kde', figsize=None, ax=None, legend=True, lw=0, alph
2527
Alpha value for plot line. Defaults to 0.35.
2628
frame : bool
2729
Flag for plotting frame around figure.
28-
30+
2931
Returns
3032
-------
3133
3234
ax : matplotlib axes
3335
"""
34-
36+
3537
try:
3638
energy = trace['energy']
3739
except KeyError:
3840
print('There is no energy information in the passed trace.')
3941
return ax
40-
series_dict = {'Marginal energy distribution': energy - energy.mean(),
41-
'Energy transition distribution': np.diff(energy)}
42+
series = [('Marginal energy distribution', energy - energy.mean()),
43+
('Energy transition distribution', np.diff(energy))]
4244

4345
if figsize is None:
4446
figsize = (8, 6)
45-
47+
4648
if ax is None:
4749
_, ax = plt.subplots(figsize=figsize)
4850

4951
if kind == 'kde':
50-
for series in series_dict:
51-
density, l, u = fast_kde(series_dict[series])
52-
x = np.linspace(l, u, len(density))
53-
ax.plot(x, density, label=series, **kwargs)
54-
ax.fill_between(x, density, alpha=alpha)
55-
52+
for label, value in series:
53+
kdeplot(value, label=label, alpha=alpha, shade=True, ax=ax,
54+
**kwargs)
55+
5656
elif kind == 'hist':
57-
for series in series_dict:
58-
ax.hist(series_dict[series], lw=lw, alpha=alpha, label=series, **kwargs)
59-
57+
for label, value in series:
58+
ax.hist(value, lw=lw, alpha=alpha, label=label, **kwargs)
59+
6060
else:
6161
raise ValueError('Plot type {} not recognized.'.format(kind))
6262

6363
ax.set_xticks([])
6464
ax.set_yticks([])
65-
65+
6666
if not frame:
6767
for spine in ax.spines.values():
6868
spine.set_visible(False)
6969

7070
if legend:
7171
ax.legend()
72-
72+
7373
return ax

pymc3/plots/kdeplot.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,72 @@
11
import matplotlib.pyplot as plt
2+
import numpy as np
3+
from scipy.signal import gaussian, convolve
24

3-
from .artists import kdeplot_op, kde2plot_op
45

5-
6-
def kdeplot(data, ax=None):
6+
def kdeplot(trace_values, label=None, alpha=0.35, shade=False, ax=None,
7+
**kwargs):
78
if ax is None:
8-
_, ax = plt.subplots(1, 1, squeeze=True)
9-
kdeplot_op(ax, data)
9+
_, ax = plt.subplots()
10+
density, l, u = fast_kde(trace_values)
11+
x = np.linspace(l, u, len(density))
12+
ax.plot(x, density, label=label, **kwargs)
13+
if shade:
14+
ax.fill_between(x, density, alpha=alpha, **kwargs)
1015
return ax
1116

1217

13-
def kde2plot(x, y, grid=200, ax=None, **kwargs):
14-
if ax is None:
15-
_, ax = plt.subplots(1, 1, squeeze=True)
16-
kde2plot_op(ax, x, y, grid, **kwargs)
17-
return ax
18+
def fast_kde(x):
19+
"""
20+
A fft-based Gaussian kernel density estimate (KDE) for computing
21+
the KDE on a regular grid.
22+
The code was adapted from https://github.com/mfouesneau/faststats
23+
24+
Parameters
25+
----------
26+
27+
x : Numpy array or list
28+
29+
Returns
30+
-------
31+
32+
grid: A gridded 1D KDE of the input points (x).
33+
xmin: minimum value of x
34+
xmax: maximum value of x
35+
36+
"""
37+
x = x[~np.isnan(x)]
38+
x = x[~np.isinf(x)]
39+
n = len(x)
40+
nx = 200
41+
42+
# add small jitter in case input values are the same
43+
x += np.random.uniform(-1E-12, 1E-12, size=n)
44+
xmin, xmax = np.min(x), np.max(x)
45+
46+
# compute histogram
47+
bins = np.linspace(xmin, xmax, nx)
48+
xyi = np.digitize(x, bins)
49+
dx = (xmax - xmin) / (nx - 1)
50+
grid = np.histogram(x, bins=nx)[0]
51+
52+
# Scaling factor for bandwidth
53+
scotts_factor = n ** (-0.2)
54+
# Determine the bandwidth using Scott's rule
55+
std_x = np.std(xyi)
56+
kern_nx = int(np.round(scotts_factor * 2 * np.pi * std_x))
57+
58+
# Evaluate the gaussian function on the kernel grid
59+
kernel = np.reshape(gaussian(kern_nx, scotts_factor * std_x), kern_nx)
60+
61+
# Compute the KDE
62+
# use symmetric padding to correct for data boundaries in the kde
63+
npad = np.min((nx, 2 * kern_nx))
64+
65+
grid = np.concatenate([grid[npad: 0: -1], grid, grid[nx: nx - npad: -1]])
66+
grid = convolve(grid, kernel, mode='same')[npad: npad + nx]
67+
68+
norm_factor = n * dx * (2 * np.pi * std_x ** 2 * scotts_factor ** 2) ** 0.5
69+
70+
grid = grid / norm_factor
71+
72+
return grid, xmin, xmax

pymc3/plots/posteriorplot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def get_trace_dict(tr, varnames):
7979
if figsize is None:
8080
figsize = (6, 2)
8181
if ax is None:
82-
fig, ax = plt.subplots()
83-
plot_posterior_op(transform(trace), figsize=figsize, ax=ax, kde_plot=kde_plot,
82+
fig, ax = plt.subplots(figsize=figsize)
83+
plot_posterior_op(transform(trace), ax=ax, kde_plot=kde_plot,
8484
point_estimate=point_estimate, round_to=round_to,
8585
alpha_level=alpha_level, ref_val=ref_val, rope=rope, text_size=text_size, **kwargs)
8686
else:
@@ -94,7 +94,7 @@ def get_trace_dict(tr, varnames):
9494

9595
for a, v in zip(np.atleast_1d(ax), trace_dict):
9696
tr_values = transform(trace_dict[v])
97-
plot_posterior_op(tr_values, figsize=figsize, ax=a, kde_plot=kde_plot,
97+
plot_posterior_op(tr_values, ax=a, kde_plot=kde_plot,
9898
point_estimate=point_estimate, round_to=round_to,
9999
alpha_level=alpha_level, ref_val=ref_val, rope=rope, text_size=text_size, **kwargs)
100100
a.set_title(v)

pymc3/plots/utils.py

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
3-
from scipy.signal import gaussian, convolve
43
# plotting utilities can all be in this namespace
54
from ..util import get_default_varnames # pylint: disable=unused-import
65

@@ -39,60 +38,3 @@ def make_2d(a):
3938
newshape = np.product(a.shape[1:]).astype(int)
4039
a = a.reshape((n, newshape), order='F')
4140
return a
42-
43-
44-
def fast_kde(x):
45-
"""
46-
A fft-based Gaussian kernel density estimate (KDE) for computing
47-
the KDE on a regular grid.
48-
The code was adapted from https://github.com/mfouesneau/faststats
49-
50-
Parameters
51-
----------
52-
53-
x : Numpy array or list
54-
55-
Returns
56-
-------
57-
58-
grid: A gridded 1D KDE of the input points (x).
59-
xmin: minimum value of x
60-
xmax: maximum value of x
61-
62-
"""
63-
x = x[~np.isnan(x)]
64-
x = x[~np.isinf(x)]
65-
n = len(x)
66-
nx = 200
67-
68-
# add small jitter in case input values are the same
69-
x += np.random.uniform(-1E-12, 1E-12, size=n)
70-
xmin, xmax = np.min(x), np.max(x)
71-
72-
# compute histogram
73-
bins = np.linspace(xmin, xmax, nx)
74-
xyi = np.digitize(x, bins)
75-
dx = (xmax - xmin) / (nx - 1)
76-
grid = np.histogram(x, bins=nx)[0]
77-
78-
# Scaling factor for bandwidth
79-
scotts_factor = n ** (-0.2)
80-
# Determine the bandwidth using Scott's rule
81-
std_x = np.std(xyi)
82-
kern_nx = int(np.round(scotts_factor * 2 * np.pi * std_x))
83-
84-
# Evaluate the gaussian function on the kernel grid
85-
kernel = np.reshape(gaussian(kern_nx, scotts_factor * std_x), kern_nx)
86-
87-
# Compute the KDE
88-
# use symmetric padding to correct for data boundaries in the kde
89-
npad = np.min((nx, 2 * kern_nx))
90-
91-
grid = np.concatenate([grid[npad: 0: -1], grid, grid[nx: nx - npad: -1]])
92-
grid = convolve(grid, kernel, mode='same')[npad: npad + nx]
93-
94-
norm_factor = n * dx * (2 * np.pi * std_x ** 2 * scotts_factor ** 2) ** 0.5
95-
96-
grid = grid / norm_factor
97-
98-
return grid, xmin, xmax

0 commit comments

Comments
 (0)