
Commit 2572852

AustinRochford authored and twiecki committed
ENH Mixture Models (#1437)
* First pass at mixture modelling
* No longer necessary to reference self.comp_dists directly in logp
* Add dimension internally (when necessary)
* Import get_tau_sd
* Misc bugfixes
* Add sampling to Mixtures
* Differentiate between Discrete and Continuous mixtures when possible
* Add support for 2D weights
* Gracefully try to calculate mean and mode defaults
* Add docstrings for Mixture classes
* Export mixture models
* Reference self.comp_dists
* Remove unnecessary pm.
* Add Mixture tests
* Add missing imports
* Add marginalized Gaussian mixture model example
* Calculate the mode of the mixture distribution correctly
1 parent f1e622f commit 2572852
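
For orientation, here is a minimal sketch of how the API introduced by this commit is meant to be used. It is not part of the diff; the synthetic data and two-component setup are illustrative assumptions, mirroring the new tests below.

    import numpy as np
    import pymc3 as pm

    # synthetic two-component data (placeholder)
    y = np.concatenate([np.random.normal(0., 1., 750),
                        np.random.normal(5., 1., 250)])

    with pm.Model() as model:
        w = pm.Dirichlet('w', np.ones(2))        # mixture weights
        mu = pm.Normal('mu', 0., 10., shape=2)   # component means
        tau = pm.Gamma('tau', 1., 1., shape=2)   # component precisions

        # marginalized two-component normal mixture likelihood
        y_obs = pm.NormalMixture('y_obs', w, mu, tau=tau, observed=y)

        trace = pm.sample(5000, step=pm.Metropolis())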

File tree

5 files changed: +609 −1 lines changed


docs/source/examples.rst

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ Mixture Models
 
 .. toctree::
     notebooks/gaussian_mixture_model.ipynb
+    notebooks/marginalized_gaussian_mixture_model.ipynb
     notebooks/gaussian-mixture-model-advi.ipynb
     notebooks/dp_mix.ipynb
 

docs/source/notebooks/marginalized_gaussian_mixture_model.ipynb

Lines changed: 319 additions & 0 deletions
Large diffs are not rendered by default.
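
Since the notebook is not rendered here, a note on the term: "marginalized" means the discrete component indicator z_i is summed out of the joint density, so the sampler never has to visit latent labels. Using the notation of the docstrings added below (K components with weights w):

    p(y_i | w, \mu, \sigma) = \sum_{k = 1}^{K} p(z_i = k | w) \, p(y_i | z_i = k, \mu_k, \sigma_k)
                            = \sum_{k = 1}^{K} w_k \, N(y_i | \mu_k, \sigma_k^2)

This is exactly the density that Mixture.logp in this commit evaluates.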

pymc3/distributions/__init__.py

Lines changed: 6 additions & 1 deletion

@@ -45,6 +45,9 @@
 from .distribution import TensorType
 from .distribution import draw_values
 
+from .mixture import Mixture
+from .mixture import NormalMixture
+
 from .multivariate import MvNormal
 from .multivariate import MvStudentT
 from .multivariate import Dirichlet
@@ -112,5 +115,7 @@
     'AR1',
     'GaussianRandomWalk',
     'GARCH11',
-    'SkewNormal'
+    'SkewNormal',
+    'Mixture',
+    'NormalMixture'
 ]

pymc3/distributions/mixture.py

Lines changed: 169 additions & 0 deletions

@@ -0,0 +1,169 @@
import numpy as np
import theano.tensor as tt

from ..math import logsumexp
from .dist_math import bound
from .distribution import Discrete, Distribution, draw_values, generate_samples
from .continuous import get_tau_sd, Normal


def all_discrete(comp_dists):
    """
    Determine if all distributions in comp_dists are discrete
    """
    if isinstance(comp_dists, Distribution):
        return isinstance(comp_dists, Discrete)
    else:
        return all(isinstance(comp_dist, Discrete) for comp_dist in comp_dists)


class Mixture(Distribution):
    R"""
    Mixture log-likelihood

    Often used to model subpopulation heterogeneity

    .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)

    ========  ============================================
    Support   :math:`\cap_{i = 1}^n \textrm{support}(f_i)`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    ========  ============================================

    Parameters
    ----------
    w : array of floats
        w >= 0 and w <= 1
        the mixture weights
    comp_dists : multidimensional PyMC3 distribution or iterable of one-dimensional PyMC3 distributions
        the component distributions :math:`f_1, \ldots, f_n`
    """
    def __init__(self, w, comp_dists, *args, **kwargs):
        shape = kwargs.pop('shape', ())

        self.w = w
        self.comp_dists = comp_dists

        defaults = kwargs.pop('defaults', [])

        if all_discrete(comp_dists):
            dtype = kwargs.pop('dtype', 'int64')
        else:
            dtype = kwargs.pop('dtype', 'float64')

        # Gracefully try to calculate mean and mode defaults; components
        # without a mean/mode attribute simply contribute no default
        try:
            self.mean = (w * self._comp_means()).sum(axis=-1)

            if 'mean' not in defaults:
                defaults.append('mean')
        except AttributeError:
            pass

        try:
            comp_modes = self._comp_modes()
            comp_mode_logps = self.logp(comp_modes)
            self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]

            if 'mode' not in defaults:
                defaults.append('mode')
        except AttributeError:
            pass

        super(Mixture, self).__init__(shape, dtype, defaults=defaults,
                                      *args, **kwargs)

    def _comp_logp(self, value):
        comp_dists = self.comp_dists

        try:
            value_ = value if value.ndim > 1 else tt.shape_padright(value)

            return comp_dists.logp(value_)
        except AttributeError:
            return tt.stack([comp_dist.logp(value) for comp_dist in comp_dists],
                            axis=1)

    def _comp_means(self):
        try:
            return self.comp_dists.mean
        except AttributeError:
            return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
                            axis=1)

    def _comp_modes(self):
        try:
            return self.comp_dists.mode
        except AttributeError:
            return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
                            axis=1)

    def _comp_samples(self, point=None, size=None, repeat=None):
        try:
            samples = self.comp_dists.random(point=point, size=size, repeat=repeat)
        except AttributeError:
            samples = np.column_stack([comp_dist.random(point=point, size=size, repeat=repeat)
                                       for comp_dist in self.comp_dists])

        return np.squeeze(samples)

    def logp(self, value):
        w = self.w

        # Marginalize over components on the log scale (log-sum-exp) for
        # numerical stability; bound enforces that w is a valid simplex
        return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1).sum(),
                     w >= 0, w <= 1, tt.allclose(w.sum(axis=-1), 1))

    def random(self, point=None, size=None, repeat=None):
        def random_choice(*args, **kwargs):
            w = kwargs.pop('w')
            w /= w.sum(axis=-1, keepdims=True)
            k = w.shape[-1]

            if w.ndim > 1:
                return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
            else:
                return np.random.choice(k, p=w, *args, **kwargs)

        w = draw_values([self.w], point=point)

        # Draw component indices according to the weights, then select the
        # corresponding component draws
        w_samples = generate_samples(random_choice,
                                     w=w,
                                     broadcast_shape=w.shape[:-1] or (1,),
                                     dist_shape=self.shape,
                                     size=size).squeeze()
        comp_samples = self._comp_samples(point=point, size=size, repeat=repeat)

        if comp_samples.ndim > 1:
            return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples])
        else:
            return np.squeeze(comp_samples[w_samples])


class NormalMixture(Mixture):
    R"""
    Normal mixture log-likelihood

    .. math:: f(x \mid w, \mu, \sigma^2) = \sum_{i = 1}^n w_i N(x \mid \mu_i, \sigma^2_i)

    ========  =======================================
    Support   :math:`x \in \mathbb{R}`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    Variance  :math:`\sum_{i = 1}^n w_i^2 \sigma^2_i`
    ========  =======================================

    Parameters
    ----------
    w : array of floats
        w >= 0 and w <= 1
        the mixture weights
    mu : array of floats
        the component means
    sd : array of floats
        the component standard deviations
    tau : array of floats
        the component precisions
    """
    def __init__(self, w, mu, *args, **kwargs):
        _, sd = get_tau_sd(tau=kwargs.pop('tau', None),
                           sd=kwargs.pop('sd', None))

        super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
                                            *args, **kwargs)
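
A note on logp above: summing component densities on the probability scale can underflow, so the implementation marginalizes on the log scale via logsumexp. Here is a minimal NumPy sketch of the same computation — illustrative only; mixture_logp and its arguments are hypothetical stand-ins for the Theano graph:

    import numpy as np
    from scipy.special import logsumexp  # older SciPy exposes this as scipy.misc.logsumexp

    def mixture_logp(w, comp_logps):
        # comp_logps: (n_points, n_components) array of each component's
        # log-density at each point, mirroring Mixture._comp_logp
        # log sum_k w_k f_k(x) == logsumexp_k(log w_k + log f_k(x))
        return logsumexp(np.log(w) + comp_logps, axis=-1).sum()

    # two-component normal example
    x = np.array([0., 1., 2.])
    w = np.array([0.7, 0.3])
    mu = np.array([0., 5.])
    comp_logps = -0.5 * np.log(2 * np.pi) - 0.5 * (x[:, None] - mu) ** 2
    print(mixture_logp(w, comp_logps))  # total log-likelihood of x under the mixture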

pymc3/tests/test_mixture.py

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
import numpy as np
from numpy.testing import assert_allclose

from .helpers import SeededTest
from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample


# Generate data
def generate_normal_mixture_data(w, mu, sd, size=1000):
    component = np.random.choice(w.size, size=size, p=w)
    x = np.random.normal(mu[component], sd[component], size=size)

    return x


def generate_poisson_mixture_data(w, mu, size=1000):
    component = np.random.choice(w.size, size=size, p=w)
    x = np.random.poisson(mu[component], size=size)

    return x


class TestMixture(SeededTest):
    @classmethod
    def setUpClass(cls):
        super(TestMixture, cls).setUpClass()

        cls.norm_w = np.array([0.75, 0.25])
        cls.norm_mu = np.array([0., 5.])
        cls.norm_sd = np.ones_like(cls.norm_mu)
        cls.norm_x = generate_normal_mixture_data(cls.norm_w, cls.norm_mu, cls.norm_sd, size=1000)

        cls.pois_w = np.array([0.4, 0.6])
        cls.pois_mu = np.array([5., 20.])
        cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)

    def test_mixture_list_of_normals(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.norm_w))

            mu = Normal('mu', 0., 10., shape=self.norm_w.size)
            tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

            x_obs = Mixture('x_obs', w,
                            [Normal.dist(mu[0], tau=tau[0]),
                             Normal.dist(mu[1], tau=tau[1])],
                            observed=self.norm_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.norm_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.norm_mu),
                        rtol=0.1, atol=0.1)

    def test_normal_mixture(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.norm_w))

            mu = Normal('mu', 0., 10., shape=self.norm_w.size)
            tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

            x_obs = NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.norm_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.norm_mu),
                        rtol=0.1, atol=0.1)

    def test_poisson_mixture(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.pois_w))

            mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

            x_obs = Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.pois_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.pois_mu),
                        rtol=0.1, atol=0.1)

    def test_mixture_list_of_poissons(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.pois_w))

            mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

            x_obs = Mixture('x_obs', w,
                            [Poisson.dist(mu[0]), Poisson.dist(mu[1])],
                            observed=self.pois_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.pois_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.pois_mu),
                        rtol=0.1, atol=0.1)
