Commit 5c740d4

Merge pull request #2342 from bwengals/gp-sample-fix2
Fix for gp.sample_gp (redone)
2 parents 4bba2f0 + 5e6e299 commit 5c740d4

2 files changed: +34 -24 lines changed


pymc3/gp/gp.py

Lines changed: 27 additions & 22 deletions
@@ -15,47 +15,47 @@
 
 class GP(Continuous):
     """Gausian process
-
+
     Parameters
     ----------
     mean_func : Mean
         Mean function of Gaussian process
     cov_func : Covariance
         Covariance function of Gaussian process
     X : array
-        Grid of points to evaluate Gaussian process over. Only required if the
+        Grid of points to evaluate Gaussian process over. Only required if the
         GP is not an observed variable.
     sigma : scalar or array
         Observation standard deviation (defaults to zero)
     """
     def __init__(self, mean_func=None, cov_func=None, X=None, sigma=0, *args, **kwargs):
-
+
         if mean_func is None:
             self.M = Zero()
         else:
             if not isinstance(mean_func, Mean):
                 raise ValueError('mean_func must be a subclass of Mean')
             self.M = mean_func
-
+
         if cov_func is None:
             raise ValueError('A covariance function must be specified for GPP')
         if not isinstance(cov_func, Covariance):
             raise ValueError('cov_func must be a subclass of Covariance')
         self.K = cov_func
-
+
         self.sigma = sigma
-
+
         if X is not None:
             self.X = X
             self.mean = self.mode = self.M(X)
             kwargs.setdefault("shape", X.squeeze().shape)
-
+
         super(GP, self).__init__(*args, **kwargs)
-
+
     def random(self, point=None, size=None, **kwargs):
         X = self.X
         mu, cov = draw_values([self.M(X).squeeze(), self.K(X) + np.eye(X.shape[0])*self.sigma**2], point=point)
-
+
         def _random(mean, cov, size=None):
             return stats.multivariate_normal.rvs(
                 mean, cov, None if size == mean.shape else size)
@@ -74,9 +74,9 @@ def logp(self, Y, X=None):
         Sigma = self.K(X) + tt.eye(X.shape[0])*self.sigma**2
 
         return MvNormal.dist(mu, Sigma).logp(Y)
-
 
-def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, random_seed=None, progressbar=True):
+
+def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, random_seed=None, progressbar=True, chol_const=True):
     """Generate samples from a posterior Gaussian process.
 
     Parameters
@@ -92,38 +92,41 @@ def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, ran
         length of `trace`
     obs_noise : bool
         Flag for including observation noise in sample. Defaults to True.
-    model : Model
+    model : Model
        Model used to generate `trace`. Optional if in `with` context manager.
     random_seed : integer > 0
         Random number seed for sampling.
     progressbar : bool
         Flag for showing progress bar.
-
+    chol_const : bool
+        Flag to add a small diagonal to the posterior covariance
+        for numerical stability
+
     Returns
     -------
     Array of samples from posterior GP evaluated at Z.
     """
     model = modelcontext(model)
-
+
     if samples is None:
         samples = len(trace)
-
+
     if random_seed:
         np.random.seed(random_seed)
-
+
     if progressbar:
         indices = tqdm(np.random.randint(0, len(trace), samples), total=samples)
     else:
         indices = np.random.randint(0, len(trace), samples)
 
-    K = gp.distribution.K
-
+    K = gp.distribution.K
+
     data = [v for v in model.observed_RVs if v.name==gp.name][0].data
 
     X = data['X']
     Y = data['Y']
     Z = X_values
-
+
     S_xz = K(X, Z)
     S_zz = K(Z)
     if obs_noise:
@@ -136,8 +139,10 @@ def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, ran
     # Posterior covariance
     S_post = S_zz - tt.dot(tt.dot(S_xz.T, S_inv), S_xz)
 
-    gp_post = MvNormal.dist(m_post, S_post, shape=Z.shape[0])
-
+    if chol_const:
+        n = S_post.shape[0]
+        correction = 1e-6 * tt.nlinalg.trace(S_post) * tt.eye(n)
+
+    gp_post = MvNormal.dist(m_post, S_post + correction, shape=Z.shape[0])
     samples = [gp_post.random(point=trace[idx]) for idx in indices]
-
     return np.array(samples)
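
Background on the fix: the posterior covariance S_post = S_zz - S_xz.T · S_inv · S_xz is positive semi-definite in exact arithmetic, but round-off can leave it slightly indefinite, and the Cholesky factorization inside MvNormal then errors out. The chol_const branch above counters this by adding a small diagonal term scaled by the trace of S_post. Below is a minimal NumPy sketch of the same jitter trick outside of Theano/PyMC3; the rbf helper, the toy data, and the 1e-6 scale are illustrative assumptions, not part of the PyMC3 API.

import numpy as np

def rbf(A, B, lengthscale=0.3):
    # Squared-exponential kernel between the rows of A and B (illustrative helper).
    d = A[:, None, :] - B[None, :, :]
    return np.exp(-0.5 * np.sum(d**2, axis=-1) / lengthscale**2)

# Toy conditioning problem: observed inputs X, targets Y, prediction grid Z.
X = np.linspace(0, 1, 10)[:, None]
Y = np.sin(3 * X).ravel()
Z = np.linspace(0, 1, 50)[:, None]

sigma = 0.1
S_xx = rbf(X, X) + sigma**2 * np.eye(X.shape[0])
S_xz = rbf(X, Z)
S_zz = rbf(Z, Z)

S_inv = np.linalg.inv(S_xx)
m_post = S_xz.T @ S_inv @ Y
S_post = S_zz - S_xz.T @ S_inv @ S_xz

# Without jitter, S_post can pick up tiny negative eigenvalues from round-off,
# which can make np.linalg.cholesky (and MvNormal's internal Cholesky) fail.
correction = 1e-6 * np.trace(S_post) * np.eye(S_post.shape[0])
L = np.linalg.cholesky(S_post + correction)

# Draw one posterior sample the same way a stabilized MvNormal would.
sample = m_post + L @ np.random.standard_normal(Z.shape[0])
print(sample.shape)  # (50,)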

pymc3/tests/test_gp.py

Lines changed: 7 additions & 2 deletions
@@ -7,7 +7,6 @@
 import numpy.testing as npt
 import pytest
 
-
 class TestZeroMean(object):
     def test_value(self):
         X = np.linspace(0, 1, 10)[:, None]
@@ -479,7 +478,7 @@ def test_func_args(self):
 
     def test_sample(self):
         X = np.linspace(0, 1, 10)[:, None]
-        Y = np.random.randn(10, 1)
+        Y = np.random.randn(10)
         with Model() as model:
             M = gp.mean.Zero()
             l = Uniform('l', 0, 5)
@@ -488,3 +487,9 @@ def test_sample(self):
             # make a Gaussian model
             random_test = gp.GP('random_test', mean_func=M, cov_func=K, sigma=sigma, observed={'X':X, 'Y':Y})
             tr = sample(20, init=None, progressbar=False, random_seed=self.random_seed)
+
+        # test prediction
+        Z = np.linspace(0, 1, 5)[:, None]
+        with model:
+            out = gp.sample_gp(tr[-3:], gp=random_test, X_values=Z, obs_noise=False,
+                               random_seed=self.random_seed, progressbar=False, chol_const=True)
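
The added test exercises the new prediction path end to end. A rough standalone sketch of the same workflow is shown below; the ExpQuad covariance, the HalfCauchy noise prior, the toy data, and the sampler settings are assumptions chosen for illustration, while the gp.GP and gp.sample_gp calls mirror the signatures in the diff above.

import numpy as np
import pymc3 as pm
from pymc3 import gp

# Toy 1-D regression data (illustrative).
X = np.linspace(0, 1, 30)[:, None]
Y = np.sin(4 * X).ravel() + 0.1 * np.random.randn(30)

with pm.Model() as model:
    l = pm.Uniform('l', 0, 3)                    # lengthscale prior
    K = gp.cov.ExpQuad(1, l)                     # squared-exponential covariance
    sigma = pm.HalfCauchy('sigma', 2.5)          # observation-noise prior
    obs = gp.GP('obs', cov_func=K, sigma=sigma, observed={'X': X, 'Y': Y})
    tr = pm.sample(100, init=None, progressbar=False)

# Predict at new inputs; chol_const=True adds the stabilizing diagonal term.
Z = np.linspace(0, 1.2, 40)[:, None]
with model:
    post = gp.sample_gp(tr[-25:], gp=obs, X_values=Z, obs_noise=False,
                        progressbar=False, chol_const=True)
# post has one row per posterior draw and one column per point in Z.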
