-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
refactor kde-related functions and small fixes #2191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from .autocorrplot import autocorrplot | ||
from .compareplot import compareplot | ||
from .forestplot import forestplot | ||
from .kdeplot import kdeplot, kde2plot | ||
from .kdeplot import kdeplot | ||
from .posteriorplot import plot_posterior, plot_posterior_predictive_glm | ||
from .traceplot import traceplot | ||
from .energyplot import energyplot |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from .utils import fast_kde | ||
from .kdeplot import kdeplot | ||
|
||
def energyplot(trace, kind='kde', figsize=None, ax=None, legend=True, lw=0, alpha=0.5, frame=True, **kwargs): | ||
def energyplot(trace, kind='kde', figsize=None, ax=None, legend=True, lw=0, | ||
alpha=0.35, frame=True, **kwargs): | ||
"""Plot energy transition distribution and marginal energy distribution in order | ||
to diagnose poor exploration by HMC algorithms. | ||
|
||
|
@@ -37,8 +38,8 @@ def energyplot(trace, kind='kde', figsize=None, ax=None, legend=True, lw=0, alph | |
except KeyError: | ||
print('There is no energy information in the passed trace.') | ||
return ax | ||
series_dict = {'Marginal energy distribution': energy - energy.mean(), | ||
'Energy transition distribution': np.diff(energy)} | ||
series = [('Marginal energy distribution', energy - energy.mean()), | ||
('Energy transition distribution', np.diff(energy))] | ||
|
||
if figsize is None: | ||
figsize = (8, 6) | ||
|
@@ -47,15 +48,13 @@ def energyplot(trace, kind='kde', figsize=None, ax=None, legend=True, lw=0, alph | |
_, ax = plt.subplots(figsize=figsize) | ||
|
||
if kind == 'kde': | ||
for series in series_dict: | ||
density, l, u = fast_kde(series_dict[series]) | ||
x = np.linspace(l, u, len(density)) | ||
ax.plot(x, density, label=series, **kwargs) | ||
ax.fill_between(x, density, alpha=alpha) | ||
|
||
for label, value in series: | ||
kdeplot(value, label=label, alpha=alpha, shade=True, ax=ax, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we want to propagate the return values? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the comment. Not sure if I am really understanding you. I think this is what we always do, isn't it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, sorry, this is fine ( |
||
**kwargs) | ||
|
||
elif kind == 'hist': | ||
for series in series_dict: | ||
ax.hist(series_dict[series], lw=lw, alpha=alpha, label=series, **kwargs) | ||
for label, value in series: | ||
ax.hist(value, lw=lw, alpha=alpha, label=label, **kwargs) | ||
|
||
else: | ||
raise ValueError('Plot type {} not recognized.'.format(kind)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,73 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from scipy.signal import gaussian, convolve | ||
|
||
from .artists import kdeplot_op, kde2plot_op | ||
|
||
|
||
def kdeplot(data, ax=None): | ||
def kdeplot(trace_values, label=None, alpha=0.35, shade=False, ax=None, | ||
**kwargs): | ||
if ax is None: | ||
_, ax = plt.subplots(1, 1, squeeze=True) | ||
kdeplot_op(ax, data) | ||
_, ax = plt.subplots() | ||
density, l, u = fast_kde(trace_values) | ||
x = np.linspace(l, u, len(density)) | ||
ax.plot(x, density, label=label, **kwargs) | ||
if shade: | ||
ax.fill_between(x, density, alpha=alpha, **kwargs) | ||
return ax | ||
|
||
def fast_kde(x): | ||
""" | ||
A fft-based Gaussian kernel density estimate (KDE) for computing | ||
the KDE on a regular grid. | ||
The code was adapted from https://github.com/mfouesneau/faststats | ||
|
||
Parameters | ||
---------- | ||
|
||
x : Numpy array or list | ||
|
||
Returns | ||
------- | ||
|
||
grid: A gridded 1D KDE of the input points (x). | ||
xmin: minimum value of x | ||
xmax: maximum value of x | ||
|
||
""" | ||
x = x[~np.isnan(x)] | ||
x = x[~np.isinf(x)] | ||
n = len(x) | ||
nx = 200 | ||
|
||
# add small jitter in case input values are the same | ||
x += np.random.uniform(-1E-12, 1E-12, size=n) | ||
xmin, xmax = np.min(x), np.max(x) | ||
|
||
# compute histogram | ||
bins = np.linspace(xmin, xmax, nx) | ||
xyi = np.digitize(x, bins) | ||
dx = (xmax - xmin) / (nx - 1) | ||
grid = np.histogram(x, bins=nx)[0] | ||
|
||
# Scaling factor for bandwidth | ||
scotts_factor = n ** (-0.2) | ||
# Determine the bandwidth using Scott's rule | ||
std_x = np.std(xyi) | ||
kern_nx = int(np.round(scotts_factor * 2 * np.pi * std_x)) | ||
|
||
# Evaluate the gaussian function on the kernel grid | ||
kernel = np.reshape(gaussian(kern_nx, scotts_factor * std_x), kern_nx) | ||
|
||
# Compute the KDE | ||
# use symmetric padding to correct for data boundaries in the kde | ||
npad = np.min((nx, 2 * kern_nx)) | ||
|
||
grid = np.concatenate([grid[npad: 0: -1], grid, grid[nx: nx - npad: -1]]) | ||
grid = convolve(grid, kernel, mode='same')[npad: npad + nx] | ||
|
||
norm_factor = n * dx * (2 * np.pi * std_x ** 2 * scotts_factor ** 2) ** 0.5 | ||
|
||
grid = grid / norm_factor | ||
|
||
return grid, xmin, xmax | ||
|
||
|
||
def kde2plot(x, y, grid=200, ax=None, **kwargs): | ||
if ax is None: | ||
_, ax = plt.subplots(1, 1, squeeze=True) | ||
kde2plot_op(ax, x, y, grid, **kwargs) | ||
return ax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stray extra-line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I will fix it.