Skip to content

Commit 903a797

Browse files
committed
adds methods to cov functions to compute only diagonals, adds tests
1 parent 3ba16fb commit 903a797

File tree

2 files changed

+140
-23
lines changed

2 files changed

+140
-23
lines changed

pymc3/gp/cov.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import theano.tensor as tt
22
import numpy as np
33
from functools import reduce
4+
from theano import Variable
45

56
__all__ = ['ExpQuad',
67
'RatQuad',
@@ -35,16 +36,20 @@ def __init__(self, input_dim, active_dims=None):
3536
if len(active_dims) != input_dim:
3637
raise ValueError("Length of active_dims must match input_dim")
3738

38-
def __call__(self, X, Z):
39+
def __call__(self, X, Z=None, diag=False):
3940
R"""
4041
Evaluate the kernel/covariance function.
4142
4243
Parameters
4344
----------
4445
X : The training inputs to the kernel.
4546
Z : The optional prediction set of inputs the kernel. If Z is None, Z = X.
47+
daig: Return only the diagonal of the covariance function. Default is False.
4648
"""
47-
raise NotImplementedError
49+
if diag:
50+
return self.diag(X)
51+
else:
52+
return self.full(X, Z)
4853

4954
def _slice(self, X, Z):
5055
X = X[:, self.active_dims]
@@ -93,17 +98,29 @@ def __init__(self, factor_list):
9398
else:
9499
self.factor_list.append(factor)
95100

101+
def merge_factors(self, X, Z=None, diag=False):
102+
factors = []
103+
for factor in self.factor_list:
104+
if isinstance(factor, Covariance):
105+
factors.append(factor(X, Z, diag))
106+
elif hasattr(factor, "ndim"):
107+
if diag:
108+
factors.append(tt.diag(factor))
109+
else:
110+
factors.append(factor)
111+
else:
112+
factors.append(factor)
113+
return factors
114+
96115

97116
class Add(Combination):
98-
def __call__(self, X, Z=None):
99-
return reduce((lambda x, y: x + y),
100-
[k(X, Z) if isinstance(k, Covariance) else k for k in self.factor_list])
117+
def __call__(self, X, Z=None, diag=False):
118+
return reduce((lambda x, y: x + y), self.merge_factors(X, Z, diag))
101119

102120

103121
class Prod(Combination):
104-
def __call__(self, X, Z=None):
105-
return reduce((lambda x, y: x * y),
106-
[k(X, Z) if isinstance(k, Covariance) else k for k in self.factor_list])
122+
def __call__(self, X, Z=None, diag=False):
123+
return reduce((lambda x, y: x * y), self.merge_factors(X, Z, diag))
107124

108125

109126
class Stationary(Covariance):
@@ -139,6 +156,12 @@ def euclidean_dist(self, X, Z):
139156
r2 = self.square_dist(X, Z)
140157
return tt.sqrt(r2 + 1e-12)
141158

159+
def diag(self, X):
160+
return tt.ones(tt.stack([X.shape[0], ]))
161+
162+
def full(self, X, Z=None):
163+
raise NotImplementedError
164+
142165

143166
class ExpQuad(Stationary):
144167
R"""
@@ -150,7 +173,7 @@ class ExpQuad(Stationary):
150173
k(x, x') = \mathrm{exp}\left[ -\frac{(x - x')^2}{2 \ell^2} \right]
151174
"""
152175

153-
def __call__(self, X, Z=None):
176+
def full(self, X, Z=None):
154177
X, Z = self._slice(X, Z)
155178
return tt.exp( -0.5 * self.square_dist(X, Z))
156179

@@ -169,7 +192,7 @@ def __init__(self, input_dim, lengthscales, alpha, active_dims=None):
169192
self.lengthscales = lengthscales
170193
self.alpha = alpha
171194

172-
def __call__(self, X, Z=None):
195+
def full(self, X, Z=None):
173196
X, Z = self._slice(X, Z)
174197
return tt.power((1.0 + 0.5 * self.square_dist(X, Z) * (1.0 / self.alpha)), -1.0 * self.alpha)
175198

@@ -183,7 +206,7 @@ class Matern52(Stationary):
183206
k(x, x') = \left(1 + \frac{\sqrt{5(x - x')^2}}{\ell} + \frac{5(x-x')^2}{3\ell^2}\right) \mathrm{exp}\left[ - \frac{\sqrt{5(x - x')^2}}{\ell} \right]
184207
"""
185208

186-
def __call__(self, X, Z=None):
209+
def full(self, X, Z=None):
187210
X, Z = self._slice(X, Z)
188211
r = self.euclidean_dist(X, Z)
189212
return (1.0 + np.sqrt(5.0) * r + 5.0 / 3.0 * tt.square(r)) * tt.exp(-1.0 * np.sqrt(5.0) * r)
@@ -198,7 +221,7 @@ class Matern32(Stationary):
198221
k(x, x') = \left(1 + \frac{\sqrt{3(x - x')^2}}{\ell}\right)\mathrm{exp}\left[ - \frac{\sqrt{3(x - x')^2}}{\ell} \right]
199222
"""
200223

201-
def __call__(self, X, Z=None):
224+
def full(self, X, Z=None):
202225
X, Z = self._slice(X, Z)
203226
r = self.euclidean_dist(X, Z)
204227
return (1.0 + np.sqrt(3.0) * r) * tt.exp(-np.sqrt(3.0) * r)
@@ -213,7 +236,7 @@ class Exponential(Stationary):
213236
k(x, x') = \mathrm{exp}\left[ -\frac{||x - x'||}{2\ell^2} \right]
214237
"""
215238

216-
def __call__(self, X, Z=None):
239+
def full(self, X, Z=None):
217240
X, Z = self._slice(X, Z)
218241
return tt.exp(-0.5 * self.euclidean_dist(X, Z))
219242

@@ -226,7 +249,7 @@ class Cosine(Stationary):
226249
k(x, x') = \mathrm{cos}\left( \frac{||x - x'||}{ \ell^2} \right)
227250
"""
228251

229-
def __call__(self, X, Z=None):
252+
def full(self, X, Z=None):
230253
X, Z = self._slice(X, Z)
231254
return tt.cos(np.pi * self.euclidean_dist(X, Z))
232255

@@ -243,15 +266,22 @@ def __init__(self, input_dim, c, active_dims=None):
243266
Covariance.__init__(self, input_dim, active_dims)
244267
self.c = c
245268

246-
def __call__(self, X, Z=None):
269+
def _common(self, X, Z=None):
247270
X, Z = self._slice(X, Z)
248271
Xc = tt.sub(X, self.c)
272+
return X, Xc, Z
273+
274+
def full(self, X, Z=None):
275+
X, Xc, Z = self._common(X, Z)
249276
if Z is None:
250277
return tt.dot(Xc, tt.transpose(Xc))
251278
else:
252279
Zc = tt.sub(Z, self.c)
253280
return tt.dot(Xc, tt.transpose(Zc))
254281

282+
def diag(self, X):
283+
X, Xc, _ = self._common(X, None)
284+
return tt.sum(tt.square(Xc), 1)
255285

256286
class Polynomial(Linear):
257287
R"""
@@ -266,10 +296,13 @@ def __init__(self, input_dim, c, d, offset, active_dims=None):
266296
self.d = d
267297
self.offset = offset
268298

269-
def __call__(self, X, Z=None):
270-
linear = super(Polynomial, self).__call__(X, Z)
299+
def full(self, X, Z=None):
300+
linear = super(Polynomial, self).full(X, Z)
271301
return tt.power(linear + self.offset, self.d)
272302

303+
def diag(self, X):
304+
linear = super(Polynomial, self).diag(X)
305+
return tt.power(linear + self.offset, self.d)
273306

274307
class WarpedInput(Covariance):
275308
R"""
@@ -298,13 +331,17 @@ def __init__(self, input_dim, cov_func, warp_func, args=None, active_dims=None):
298331
self.args = args
299332
self.cov_func = cov_func
300333

301-
def __call__(self, X, Z=None):
334+
def full(self, X, Z=None):
302335
X, Z = self._slice(X, Z)
303336
if Z is None:
304337
return self.cov_func(self.w(X, self.args), Z)
305338
else:
306339
return self.cov_func(self.w(X, self.args), self.w(Z, self.args))
307340

341+
def diag(self, X):
342+
X, _ = self._slice(X, None)
343+
return self.cov_func(self.w(X, self.args), diag=True)
344+
308345

309346
class Gibbs(Covariance):
310347
R"""
@@ -332,7 +369,7 @@ def __init__(self, input_dim, lengthscale_func, args=None, active_dims=None):
332369
raise NotImplementedError("Higher dimensional inputs are untested")
333370
if not callable(lengthscale_func):
334371
raise TypeError("lengthscale_func must be callable")
335-
self.ell = handle_args(lengthscale_func, args)
372+
self.lfunc = handle_args(lengthscale_func, args)
336373
self.args = args
337374

338375
def square_dist(self, X, Z):
@@ -348,20 +385,23 @@ def square_dist(self, X, Z):
348385
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Zs, (1, -1)))
349386
return tt.clip(sqd, 0.0, np.inf)
350387

351-
def __call__(self, X, Z=None):
388+
def full(self, X, Z=None):
352389
X, Z = self._slice(X, Z)
353-
rx = self.ell(X, self.args)
390+
rx = self.lfunc(X, self.args)
354391
rx2 = tt.reshape(tt.square(rx), (-1, 1))
355392
if Z is None:
356393
r2 = self.square_dist(X,X)
357-
rz = self.ell(X, self.args)
394+
rz = self.lfunc(X, self.args)
358395
else:
359396
r2 = self.square_dist(X,Z)
360-
rz = self.ell(Z, self.args)
397+
rz = self.lfunc(Z, self.args)
361398
rz2 = tt.reshape(tt.square(rz), (1, -1))
362399
return tt.sqrt((2.0 * tt.dot(rx, tt.transpose(rz))) / (rx2 + rz2)) *\
363400
tt.exp(-1.0 * r2 / (rx2 + rz2))
364401

402+
def diag(self, X):
403+
return tt.ones(tt.stack([X.shape[0], ]))
404+
365405

366406
def handle_args(func, args):
367407
def f(x, args):

0 commit comments

Comments
 (0)