1
1
import theano .tensor as tt
2
2
import numpy as np
3
3
from functools import reduce
4
+ from theano import Variable
4
5
5
6
__all__ = ['ExpQuad' ,
6
7
'RatQuad' ,
@@ -35,16 +36,20 @@ def __init__(self, input_dim, active_dims=None):
35
36
if len (active_dims ) != input_dim :
36
37
raise ValueError ("Length of active_dims must match input_dim" )
37
38
38
- def __call__ (self , X , Z ):
39
+ def __call__ (self , X , Z = None , diag = False ):
39
40
R"""
40
41
Evaluate the kernel/covariance function.
41
42
42
43
Parameters
43
44
----------
44
45
X : The training inputs to the kernel.
45
46
Z : The optional prediction set of inputs the kernel. If Z is None, Z = X.
47
+ daig: Return only the diagonal of the covariance function. Default is False.
46
48
"""
47
- raise NotImplementedError
49
+ if diag :
50
+ return self .diag (X )
51
+ else :
52
+ return self .full (X , Z )
48
53
49
54
def _slice (self , X , Z ):
50
55
X = X [:, self .active_dims ]
@@ -93,17 +98,29 @@ def __init__(self, factor_list):
93
98
else :
94
99
self .factor_list .append (factor )
95
100
101
+ def merge_factors (self , X , Z = None , diag = False ):
102
+ factors = []
103
+ for factor in self .factor_list :
104
+ if isinstance (factor , Covariance ):
105
+ factors .append (factor (X , Z , diag ))
106
+ elif hasattr (factor , "ndim" ):
107
+ if diag :
108
+ factors .append (tt .diag (factor ))
109
+ else :
110
+ factors .append (factor )
111
+ else :
112
+ factors .append (factor )
113
+ return factors
114
+
96
115
97
116
class Add (Combination ):
98
- def __call__ (self , X , Z = None ):
99
- return reduce ((lambda x , y : x + y ),
100
- [k (X , Z ) if isinstance (k , Covariance ) else k for k in self .factor_list ])
117
+ def __call__ (self , X , Z = None , diag = False ):
118
+ return reduce ((lambda x , y : x + y ), self .merge_factors (X , Z , diag ))
101
119
102
120
103
121
class Prod (Combination ):
104
- def __call__ (self , X , Z = None ):
105
- return reduce ((lambda x , y : x * y ),
106
- [k (X , Z ) if isinstance (k , Covariance ) else k for k in self .factor_list ])
122
+ def __call__ (self , X , Z = None , diag = False ):
123
+ return reduce ((lambda x , y : x * y ), self .merge_factors (X , Z , diag ))
107
124
108
125
109
126
class Stationary (Covariance ):
@@ -139,6 +156,12 @@ def euclidean_dist(self, X, Z):
139
156
r2 = self .square_dist (X , Z )
140
157
return tt .sqrt (r2 + 1e-12 )
141
158
159
+ def diag (self , X ):
160
+ return tt .ones (tt .stack ([X .shape [0 ], ]))
161
+
162
+ def full (self , X , Z = None ):
163
+ raise NotImplementedError
164
+
142
165
143
166
class ExpQuad (Stationary ):
144
167
R"""
@@ -150,7 +173,7 @@ class ExpQuad(Stationary):
150
173
k(x, x') = \mathrm{exp}\left[ -\frac{(x - x')^2}{2 \ell^2} \right]
151
174
"""
152
175
153
- def __call__ (self , X , Z = None ):
176
+ def full (self , X , Z = None ):
154
177
X , Z = self ._slice (X , Z )
155
178
return tt .exp ( - 0.5 * self .square_dist (X , Z ))
156
179
@@ -169,7 +192,7 @@ def __init__(self, input_dim, lengthscales, alpha, active_dims=None):
169
192
self .lengthscales = lengthscales
170
193
self .alpha = alpha
171
194
172
- def __call__ (self , X , Z = None ):
195
+ def full (self , X , Z = None ):
173
196
X , Z = self ._slice (X , Z )
174
197
return tt .power ((1.0 + 0.5 * self .square_dist (X , Z ) * (1.0 / self .alpha )), - 1.0 * self .alpha )
175
198
@@ -183,7 +206,7 @@ class Matern52(Stationary):
183
206
k(x, x') = \left(1 + \frac{\sqrt{5(x - x')^2}}{\ell} + \frac{5(x-x')^2}{3\ell^2}\right) \mathrm{exp}\left[ - \frac{\sqrt{5(x - x')^2}}{\ell} \right]
184
207
"""
185
208
186
- def __call__ (self , X , Z = None ):
209
+ def full (self , X , Z = None ):
187
210
X , Z = self ._slice (X , Z )
188
211
r = self .euclidean_dist (X , Z )
189
212
return (1.0 + np .sqrt (5.0 ) * r + 5.0 / 3.0 * tt .square (r )) * tt .exp (- 1.0 * np .sqrt (5.0 ) * r )
@@ -198,7 +221,7 @@ class Matern32(Stationary):
198
221
k(x, x') = \left(1 + \frac{\sqrt{3(x - x')^2}}{\ell}\right)\mathrm{exp}\left[ - \frac{\sqrt{3(x - x')^2}}{\ell} \right]
199
222
"""
200
223
201
- def __call__ (self , X , Z = None ):
224
+ def full (self , X , Z = None ):
202
225
X , Z = self ._slice (X , Z )
203
226
r = self .euclidean_dist (X , Z )
204
227
return (1.0 + np .sqrt (3.0 ) * r ) * tt .exp (- np .sqrt (3.0 ) * r )
@@ -213,7 +236,7 @@ class Exponential(Stationary):
213
236
k(x, x') = \mathrm{exp}\left[ -\frac{||x - x'||}{2\ell^2} \right]
214
237
"""
215
238
216
- def __call__ (self , X , Z = None ):
239
+ def full (self , X , Z = None ):
217
240
X , Z = self ._slice (X , Z )
218
241
return tt .exp (- 0.5 * self .euclidean_dist (X , Z ))
219
242
@@ -226,7 +249,7 @@ class Cosine(Stationary):
226
249
k(x, x') = \mathrm{cos}\left( \frac{||x - x'||}{ \ell^2} \right)
227
250
"""
228
251
229
- def __call__ (self , X , Z = None ):
252
+ def full (self , X , Z = None ):
230
253
X , Z = self ._slice (X , Z )
231
254
return tt .cos (np .pi * self .euclidean_dist (X , Z ))
232
255
@@ -243,15 +266,22 @@ def __init__(self, input_dim, c, active_dims=None):
243
266
Covariance .__init__ (self , input_dim , active_dims )
244
267
self .c = c
245
268
246
- def __call__ (self , X , Z = None ):
269
+ def _common (self , X , Z = None ):
247
270
X , Z = self ._slice (X , Z )
248
271
Xc = tt .sub (X , self .c )
272
+ return X , Xc , Z
273
+
274
+ def full (self , X , Z = None ):
275
+ X , Xc , Z = self ._common (X , Z )
249
276
if Z is None :
250
277
return tt .dot (Xc , tt .transpose (Xc ))
251
278
else :
252
279
Zc = tt .sub (Z , self .c )
253
280
return tt .dot (Xc , tt .transpose (Zc ))
254
281
282
+ def diag (self , X ):
283
+ X , Xc , _ = self ._common (X , None )
284
+ return tt .sum (tt .square (Xc ), 1 )
255
285
256
286
class Polynomial (Linear ):
257
287
R"""
@@ -266,10 +296,13 @@ def __init__(self, input_dim, c, d, offset, active_dims=None):
266
296
self .d = d
267
297
self .offset = offset
268
298
269
- def __call__ (self , X , Z = None ):
270
- linear = super (Polynomial , self ).__call__ (X , Z )
299
+ def full (self , X , Z = None ):
300
+ linear = super (Polynomial , self ).full (X , Z )
271
301
return tt .power (linear + self .offset , self .d )
272
302
303
+ def diag (self , X ):
304
+ linear = super (Polynomial , self ).diag (X )
305
+ return tt .power (linear + self .offset , self .d )
273
306
274
307
class WarpedInput (Covariance ):
275
308
R"""
@@ -298,13 +331,17 @@ def __init__(self, input_dim, cov_func, warp_func, args=None, active_dims=None):
298
331
self .args = args
299
332
self .cov_func = cov_func
300
333
301
- def __call__ (self , X , Z = None ):
334
+ def full (self , X , Z = None ):
302
335
X , Z = self ._slice (X , Z )
303
336
if Z is None :
304
337
return self .cov_func (self .w (X , self .args ), Z )
305
338
else :
306
339
return self .cov_func (self .w (X , self .args ), self .w (Z , self .args ))
307
340
341
+ def diag (self , X ):
342
+ X , _ = self ._slice (X , None )
343
+ return self .cov_func (self .w (X , self .args ), diag = True )
344
+
308
345
309
346
class Gibbs (Covariance ):
310
347
R"""
@@ -332,7 +369,7 @@ def __init__(self, input_dim, lengthscale_func, args=None, active_dims=None):
332
369
raise NotImplementedError ("Higher dimensional inputs are untested" )
333
370
if not callable (lengthscale_func ):
334
371
raise TypeError ("lengthscale_func must be callable" )
335
- self .ell = handle_args (lengthscale_func , args )
372
+ self .lfunc = handle_args (lengthscale_func , args )
336
373
self .args = args
337
374
338
375
def square_dist (self , X , Z ):
@@ -348,20 +385,23 @@ def square_dist(self, X, Z):
348
385
(tt .reshape (Xs , (- 1 , 1 )) + tt .reshape (Zs , (1 , - 1 )))
349
386
return tt .clip (sqd , 0.0 , np .inf )
350
387
351
- def __call__ (self , X , Z = None ):
388
+ def full (self , X , Z = None ):
352
389
X , Z = self ._slice (X , Z )
353
- rx = self .ell (X , self .args )
390
+ rx = self .lfunc (X , self .args )
354
391
rx2 = tt .reshape (tt .square (rx ), (- 1 , 1 ))
355
392
if Z is None :
356
393
r2 = self .square_dist (X ,X )
357
- rz = self .ell (X , self .args )
394
+ rz = self .lfunc (X , self .args )
358
395
else :
359
396
r2 = self .square_dist (X ,Z )
360
- rz = self .ell (Z , self .args )
397
+ rz = self .lfunc (Z , self .args )
361
398
rz2 = tt .reshape (tt .square (rz ), (1 , - 1 ))
362
399
return tt .sqrt ((2.0 * tt .dot (rx , tt .transpose (rz ))) / (rx2 + rz2 )) * \
363
400
tt .exp (- 1.0 * r2 / (rx2 + rz2 ))
364
401
402
+ def diag (self , X ):
403
+ return tt .ones (tt .stack ([X .shape [0 ], ]))
404
+
365
405
366
406
def handle_args (func , args ):
367
407
def f (x , args ):
0 commit comments