Skip to content

Commit 4bba2f0

Browse files
authored
Merge pull request #2344 from bwengals/cov_diags
Methods to only compute diagonal of covariance functions
2 parents 326756e + 2303379 commit 4bba2f0

File tree

2 files changed

+140
-23
lines changed

2 files changed

+140
-23
lines changed

pymc3/gp/cov.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import theano.tensor as tt
22
import numpy as np
33
from functools import reduce
4+
from theano import Variable
45

56
__all__ = ['ExpQuad',
67
'RatQuad',
@@ -35,16 +36,20 @@ def __init__(self, input_dim, active_dims=None):
3536
if len(active_dims) != input_dim:
3637
raise ValueError("Length of active_dims must match input_dim")
3738

38-
def __call__(self, X, Z):
39+
def __call__(self, X, Z=None, diag=False):
3940
R"""
4041
Evaluate the kernel/covariance function.
4142
4243
Parameters
4344
----------
4445
X : The training inputs to the kernel.
4546
Z : The optional prediction set of inputs the kernel. If Z is None, Z = X.
47+
diag: Return only the diagonal of the covariance function. Default is False.
4648
"""
47-
raise NotImplementedError
49+
if diag:
50+
return self.diag(X)
51+
else:
52+
return self.full(X, Z)
4853

4954
def _slice(self, X, Z):
5055
X = X[:, self.active_dims]
@@ -93,17 +98,29 @@ def __init__(self, factor_list):
9398
else:
9499
self.factor_list.append(factor)
95100

101+
def merge_factors(self, X, Z=None, diag=False):
102+
factor_list = []
103+
for factor in self.factor_list:
104+
if isinstance(factor, Covariance):
105+
factor_list.append(factor(X, Z, diag))
106+
elif hasattr(factor, "ndim"):
107+
if diag:
108+
factor_list.append(tt.diag(factor))
109+
else:
110+
factor_list.append(factor)
111+
else:
112+
factor_list.append(factor)
113+
return factor_list
114+
96115

97116
class Add(Combination):
98-
def __call__(self, X, Z=None):
99-
return reduce((lambda x, y: x + y),
100-
[k(X, Z) if isinstance(k, Covariance) else k for k in self.factor_list])
117+
def __call__(self, X, Z=None, diag=False):
118+
return reduce((lambda x, y: x + y), self.merge_factors(X, Z, diag))
101119

102120

103121
class Prod(Combination):
104-
def __call__(self, X, Z=None):
105-
return reduce((lambda x, y: x * y),
106-
[k(X, Z) if isinstance(k, Covariance) else k for k in self.factor_list])
122+
def __call__(self, X, Z=None, diag=False):
123+
return reduce((lambda x, y: x * y), self.merge_factors(X, Z, diag))
107124

108125

109126
class Stationary(Covariance):
@@ -137,6 +154,12 @@ def euclidean_dist(self, X, Z):
137154
r2 = self.square_dist(X, Z)
138155
return tt.sqrt(r2 + 1e-12)
139156

157+
def diag(self, X):
158+
return tt.ones(tt.stack([X.shape[0], ]))
159+
160+
def full(self, X, Z=None):
161+
raise NotImplementedError
162+
140163

141164
class ExpQuad(Stationary):
142165
R"""
@@ -148,7 +171,7 @@ class ExpQuad(Stationary):
148171
k(x, x') = \mathrm{exp}\left[ -\frac{(x - x')^2}{2 \ell^2} \right]
149172
"""
150173

151-
def __call__(self, X, Z=None):
174+
def full(self, X, Z=None):
152175
X, Z = self._slice(X, Z)
153176
return tt.exp( -0.5 * self.square_dist(X, Z))
154177

@@ -166,7 +189,7 @@ def __init__(self, input_dim, lengthscales, alpha, active_dims=None):
166189
super(RatQuad, self).__init__(input_dim, lengthscales, active_dims)
167190
self.alpha = alpha
168191

169-
def __call__(self, X, Z=None):
192+
def full(self, X, Z=None):
170193
X, Z = self._slice(X, Z)
171194
return tt.power((1.0 + 0.5 * self.square_dist(X, Z) * (1.0 / self.alpha)), -1.0 * self.alpha)
172195

@@ -180,7 +203,7 @@ class Matern52(Stationary):
180203
k(x, x') = \left(1 + \frac{\sqrt{5(x - x')^2}}{\ell} + \frac{5(x-x')^2}{3\ell^2}\right) \mathrm{exp}\left[ - \frac{\sqrt{5(x - x')^2}}{\ell} \right]
181204
"""
182205

183-
def __call__(self, X, Z=None):
206+
def full(self, X, Z=None):
184207
X, Z = self._slice(X, Z)
185208
r = self.euclidean_dist(X, Z)
186209
return (1.0 + np.sqrt(5.0) * r + 5.0 / 3.0 * tt.square(r)) * tt.exp(-1.0 * np.sqrt(5.0) * r)
@@ -195,7 +218,7 @@ class Matern32(Stationary):
195218
k(x, x') = \left(1 + \frac{\sqrt{3(x - x')^2}}{\ell}\right)\mathrm{exp}\left[ - \frac{\sqrt{3(x - x')^2}}{\ell} \right]
196219
"""
197220

198-
def __call__(self, X, Z=None):
221+
def full(self, X, Z=None):
199222
X, Z = self._slice(X, Z)
200223
r = self.euclidean_dist(X, Z)
201224
return (1.0 + np.sqrt(3.0) * r) * tt.exp(-np.sqrt(3.0) * r)
@@ -210,7 +233,7 @@ class Exponential(Stationary):
210233
k(x, x') = \mathrm{exp}\left[ -\frac{||x - x'||}{2\ell^2} \right]
211234
"""
212235

213-
def __call__(self, X, Z=None):
236+
def full(self, X, Z=None):
214237
X, Z = self._slice(X, Z)
215238
return tt.exp(-0.5 * self.euclidean_dist(X, Z))
216239

@@ -223,7 +246,7 @@ class Cosine(Stationary):
223246
k(x, x') = \mathrm{cos}\left( \frac{||x - x'||}{ \ell^2} \right)
224247
"""
225248

226-
def __call__(self, X, Z=None):
249+
def full(self, X, Z=None):
227250
X, Z = self._slice(X, Z)
228251
return tt.cos(np.pi * self.euclidean_dist(X, Z))
229252

@@ -240,15 +263,22 @@ def __init__(self, input_dim, c, active_dims=None):
240263
super(Linear, self).__init__(input_dim, active_dims)
241264
self.c = c
242265

243-
def __call__(self, X, Z=None):
266+
def _common(self, X, Z=None):
244267
X, Z = self._slice(X, Z)
245268
Xc = tt.sub(X, self.c)
269+
return X, Xc, Z
270+
271+
def full(self, X, Z=None):
272+
X, Xc, Z = self._common(X, Z)
246273
if Z is None:
247274
return tt.dot(Xc, tt.transpose(Xc))
248275
else:
249276
Zc = tt.sub(Z, self.c)
250277
return tt.dot(Xc, tt.transpose(Zc))
251278

279+
def diag(self, X):
280+
X, Xc, _ = self._common(X, None)
281+
return tt.sum(tt.square(Xc), 1)
252282

253283
class Polynomial(Linear):
254284
R"""
@@ -263,10 +293,13 @@ def __init__(self, input_dim, c, d, offset, active_dims=None):
263293
self.d = d
264294
self.offset = offset
265295

266-
def __call__(self, X, Z=None):
267-
linear = super(Polynomial, self).__call__(X, Z)
296+
def full(self, X, Z=None):
297+
linear = super(Polynomial, self).full(X, Z)
268298
return tt.power(linear + self.offset, self.d)
269299

300+
def diag(self, X):
301+
linear = super(Polynomial, self).diag(X)
302+
return tt.power(linear + self.offset, self.d)
270303

271304
class WarpedInput(Covariance):
272305
R"""
@@ -295,13 +328,17 @@ def __init__(self, input_dim, cov_func, warp_func, args=None, active_dims=None):
295328
self.args = args
296329
self.cov_func = cov_func
297330

298-
def __call__(self, X, Z=None):
331+
def full(self, X, Z=None):
299332
X, Z = self._slice(X, Z)
300333
if Z is None:
301334
return self.cov_func(self.w(X, self.args), Z)
302335
else:
303336
return self.cov_func(self.w(X, self.args), self.w(Z, self.args))
304337

338+
def diag(self, X):
339+
X, _ = self._slice(X, None)
340+
return self.cov_func(self.w(X, self.args), diag=True)
341+
305342

306343
class Gibbs(Covariance):
307344
R"""
@@ -329,7 +366,7 @@ def __init__(self, input_dim, lengthscale_func, args=None, active_dims=None):
329366
raise NotImplementedError("Higher dimensional inputs are untested")
330367
if not callable(lengthscale_func):
331368
raise TypeError("lengthscale_func must be callable")
332-
self.ell = handle_args(lengthscale_func, args)
369+
self.lfunc = handle_args(lengthscale_func, args)
333370
self.args = args
334371

335372
def square_dist(self, X, Z):
@@ -345,20 +382,23 @@ def square_dist(self, X, Z):
345382
(tt.reshape(Xs, (-1, 1)) + tt.reshape(Zs, (1, -1)))
346383
return tt.clip(sqd, 0.0, np.inf)
347384

348-
def __call__(self, X, Z=None):
385+
def full(self, X, Z=None):
349386
X, Z = self._slice(X, Z)
350-
rx = self.ell(X, self.args)
387+
rx = self.lfunc(X, self.args)
351388
rx2 = tt.reshape(tt.square(rx), (-1, 1))
352389
if Z is None:
353390
r2 = self.square_dist(X,X)
354-
rz = self.ell(X, self.args)
391+
rz = self.lfunc(X, self.args)
355392
else:
356393
r2 = self.square_dist(X,Z)
357-
rz = self.ell(Z, self.args)
394+
rz = self.lfunc(Z, self.args)
358395
rz2 = tt.reshape(tt.square(rz), (1, -1))
359396
return tt.sqrt((2.0 * tt.dot(rx, tt.transpose(rz))) / (rx2 + rz2)) *\
360397
tt.exp(-1.0 * r2 / (rx2 + rz2))
361398

399+
def diag(self, X):
400+
return tt.ones(tt.stack([X.shape[0], ]))
401+
362402

363403
def handle_args(func, args):
364404
def f(x, args):

0 commit comments

Comments
 (0)