1
1
import theano .tensor as tt
2
2
import numpy as np
3
3
from functools import reduce
4
+ from theano import Variable
4
5
5
6
__all__ = ['ExpQuad' ,
6
7
'RatQuad' ,
@@ -35,16 +36,20 @@ def __init__(self, input_dim, active_dims=None):
35
36
if len (active_dims ) != input_dim :
36
37
raise ValueError ("Length of active_dims must match input_dim" )
37
38
38
- def __call__ (self , X , Z ):
39
+ def __call__ (self , X , Z = None , diag = False ):
39
40
R"""
40
41
Evaluate the kernel/covariance function.
41
42
42
43
Parameters
43
44
----------
44
45
X : The training inputs to the kernel.
45
46
Z : The optional prediction set of inputs the kernel. If Z is None, Z = X.
47
+ diag: Return only the diagonal of the covariance function. Default is False.
46
48
"""
47
- raise NotImplementedError
49
+ if diag :
50
+ return self .diag (X )
51
+ else :
52
+ return self .full (X , Z )
48
53
49
54
def _slice (self , X , Z ):
50
55
X = X [:, self .active_dims ]
@@ -93,17 +98,29 @@ def __init__(self, factor_list):
93
98
else :
94
99
self .factor_list .append (factor )
95
100
101
+ def merge_factors (self , X , Z = None , diag = False ):
102
+ factor_list = []
103
+ for factor in self .factor_list :
104
+ if isinstance (factor , Covariance ):
105
+ factor_list .append (factor (X , Z , diag ))
106
+ elif hasattr (factor , "ndim" ):
107
+ if diag :
108
+ factor_list .append (tt .diag (factor ))
109
+ else :
110
+ factor_list .append (factor )
111
+ else :
112
+ factor_list .append (factor )
113
+ return factor_list
114
+
96
115
97
116
class Add (Combination ):
98
- def __call__ (self , X , Z = None ):
99
- return reduce ((lambda x , y : x + y ),
100
- [k (X , Z ) if isinstance (k , Covariance ) else k for k in self .factor_list ])
117
+ def __call__ (self , X , Z = None , diag = False ):
118
+ return reduce ((lambda x , y : x + y ), self .merge_factors (X , Z , diag ))
101
119
102
120
103
121
class Prod (Combination ):
104
- def __call__ (self , X , Z = None ):
105
- return reduce ((lambda x , y : x * y ),
106
- [k (X , Z ) if isinstance (k , Covariance ) else k for k in self .factor_list ])
122
+ def __call__ (self , X , Z = None , diag = False ):
123
+ return reduce ((lambda x , y : x * y ), self .merge_factors (X , Z , diag ))
107
124
108
125
109
126
class Stationary (Covariance ):
@@ -137,6 +154,12 @@ def euclidean_dist(self, X, Z):
137
154
r2 = self .square_dist (X , Z )
138
155
return tt .sqrt (r2 + 1e-12 )
139
156
157
+ def diag (self , X ):
158
+ return tt .ones (tt .stack ([X .shape [0 ], ]))
159
+
160
+ def full (self , X , Z = None ):
161
+ raise NotImplementedError
162
+
140
163
141
164
class ExpQuad (Stationary ):
142
165
R"""
@@ -148,7 +171,7 @@ class ExpQuad(Stationary):
148
171
k(x, x') = \mathrm{exp}\left[ -\frac{(x - x')^2}{2 \ell^2} \right]
149
172
"""
150
173
151
- def __call__ (self , X , Z = None ):
174
+ def full (self , X , Z = None ):
152
175
X , Z = self ._slice (X , Z )
153
176
return tt .exp ( - 0.5 * self .square_dist (X , Z ))
154
177
@@ -166,7 +189,7 @@ def __init__(self, input_dim, lengthscales, alpha, active_dims=None):
166
189
super (RatQuad , self ).__init__ (input_dim , lengthscales , active_dims )
167
190
self .alpha = alpha
168
191
169
- def __call__ (self , X , Z = None ):
192
+ def full (self , X , Z = None ):
170
193
X , Z = self ._slice (X , Z )
171
194
return tt .power ((1.0 + 0.5 * self .square_dist (X , Z ) * (1.0 / self .alpha )), - 1.0 * self .alpha )
172
195
@@ -180,7 +203,7 @@ class Matern52(Stationary):
180
203
k(x, x') = \left(1 + \frac{\sqrt{5(x - x')^2}}{\ell} + \frac{5(x-x')^2}{3\ell^2}\right) \mathrm{exp}\left[ - \frac{\sqrt{5(x - x')^2}}{\ell} \right]
181
204
"""
182
205
183
- def __call__ (self , X , Z = None ):
206
+ def full (self , X , Z = None ):
184
207
X , Z = self ._slice (X , Z )
185
208
r = self .euclidean_dist (X , Z )
186
209
return (1.0 + np .sqrt (5.0 ) * r + 5.0 / 3.0 * tt .square (r )) * tt .exp (- 1.0 * np .sqrt (5.0 ) * r )
@@ -195,7 +218,7 @@ class Matern32(Stationary):
195
218
k(x, x') = \left(1 + \frac{\sqrt{3(x - x')^2}}{\ell}\right)\mathrm{exp}\left[ - \frac{\sqrt{3(x - x')^2}}{\ell} \right]
196
219
"""
197
220
198
- def __call__ (self , X , Z = None ):
221
+ def full (self , X , Z = None ):
199
222
X , Z = self ._slice (X , Z )
200
223
r = self .euclidean_dist (X , Z )
201
224
return (1.0 + np .sqrt (3.0 ) * r ) * tt .exp (- np .sqrt (3.0 ) * r )
@@ -210,7 +233,7 @@ class Exponential(Stationary):
210
233
k(x, x') = \mathrm{exp}\left[ -\frac{||x - x'||}{2\ell^2} \right]
211
234
"""
212
235
213
- def __call__ (self , X , Z = None ):
236
+ def full (self , X , Z = None ):
214
237
X , Z = self ._slice (X , Z )
215
238
return tt .exp (- 0.5 * self .euclidean_dist (X , Z ))
216
239
@@ -223,7 +246,7 @@ class Cosine(Stationary):
223
246
k(x, x') = \mathrm{cos}\left( \frac{||x - x'||}{ \ell^2} \right)
224
247
"""
225
248
226
- def __call__ (self , X , Z = None ):
249
+ def full (self , X , Z = None ):
227
250
X , Z = self ._slice (X , Z )
228
251
return tt .cos (np .pi * self .euclidean_dist (X , Z ))
229
252
@@ -240,15 +263,22 @@ def __init__(self, input_dim, c, active_dims=None):
240
263
super (Linear , self ).__init__ (input_dim , active_dims )
241
264
self .c = c
242
265
243
- def __call__ (self , X , Z = None ):
266
+ def _common (self , X , Z = None ):
244
267
X , Z = self ._slice (X , Z )
245
268
Xc = tt .sub (X , self .c )
269
+ return X , Xc , Z
270
+
271
+ def full (self , X , Z = None ):
272
+ X , Xc , Z = self ._common (X , Z )
246
273
if Z is None :
247
274
return tt .dot (Xc , tt .transpose (Xc ))
248
275
else :
249
276
Zc = tt .sub (Z , self .c )
250
277
return tt .dot (Xc , tt .transpose (Zc ))
251
278
279
+ def diag (self , X ):
280
+ X , Xc , _ = self ._common (X , None )
281
+ return tt .sum (tt .square (Xc ), 1 )
252
282
253
283
class Polynomial (Linear ):
254
284
R"""
@@ -263,10 +293,13 @@ def __init__(self, input_dim, c, d, offset, active_dims=None):
263
293
self .d = d
264
294
self .offset = offset
265
295
266
- def __call__ (self , X , Z = None ):
267
- linear = super (Polynomial , self ).__call__ (X , Z )
296
+ def full (self , X , Z = None ):
297
+ linear = super (Polynomial , self ).full (X , Z )
268
298
return tt .power (linear + self .offset , self .d )
269
299
300
+ def diag (self , X ):
301
+ linear = super (Polynomial , self ).diag (X )
302
+ return tt .power (linear + self .offset , self .d )
270
303
271
304
class WarpedInput (Covariance ):
272
305
R"""
@@ -295,13 +328,17 @@ def __init__(self, input_dim, cov_func, warp_func, args=None, active_dims=None):
295
328
self .args = args
296
329
self .cov_func = cov_func
297
330
298
- def __call__ (self , X , Z = None ):
331
+ def full (self , X , Z = None ):
299
332
X , Z = self ._slice (X , Z )
300
333
if Z is None :
301
334
return self .cov_func (self .w (X , self .args ), Z )
302
335
else :
303
336
return self .cov_func (self .w (X , self .args ), self .w (Z , self .args ))
304
337
338
+ def diag (self , X ):
339
+ X , _ = self ._slice (X , None )
340
+ return self .cov_func (self .w (X , self .args ), diag = True )
341
+
305
342
306
343
class Gibbs (Covariance ):
307
344
R"""
@@ -329,7 +366,7 @@ def __init__(self, input_dim, lengthscale_func, args=None, active_dims=None):
329
366
raise NotImplementedError ("Higher dimensional inputs are untested" )
330
367
if not callable (lengthscale_func ):
331
368
raise TypeError ("lengthscale_func must be callable" )
332
- self .ell = handle_args (lengthscale_func , args )
369
+ self .lfunc = handle_args (lengthscale_func , args )
333
370
self .args = args
334
371
335
372
def square_dist (self , X , Z ):
@@ -345,20 +382,23 @@ def square_dist(self, X, Z):
345
382
(tt .reshape (Xs , (- 1 , 1 )) + tt .reshape (Zs , (1 , - 1 )))
346
383
return tt .clip (sqd , 0.0 , np .inf )
347
384
348
- def __call__ (self , X , Z = None ):
385
+ def full (self , X , Z = None ):
349
386
X , Z = self ._slice (X , Z )
350
- rx = self .ell (X , self .args )
387
+ rx = self .lfunc (X , self .args )
351
388
rx2 = tt .reshape (tt .square (rx ), (- 1 , 1 ))
352
389
if Z is None :
353
390
r2 = self .square_dist (X ,X )
354
- rz = self .ell (X , self .args )
391
+ rz = self .lfunc (X , self .args )
355
392
else :
356
393
r2 = self .square_dist (X ,Z )
357
- rz = self .ell (Z , self .args )
394
+ rz = self .lfunc (Z , self .args )
358
395
rz2 = tt .reshape (tt .square (rz ), (1 , - 1 ))
359
396
return tt .sqrt ((2.0 * tt .dot (rx , tt .transpose (rz ))) / (rx2 + rz2 )) * \
360
397
tt .exp (- 1.0 * r2 / (rx2 + rz2 ))
361
398
399
+ def diag (self , X ):
400
+ return tt .ones (tt .stack ([X .shape [0 ], ]))
401
+
362
402
363
403
def handle_args (func , args ):
364
404
def f (x , args ):
0 commit comments