-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add GP Wrapped Periodic Kernel #6742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
fd1ac33
cc4eea9
59d59cd
d3b0586
9eee89a
34d3ee8
3c74942
9f36319
539061c
5fe3fba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,7 @@ | |
"Cosine", | ||
"Periodic", | ||
"WarpedInput", | ||
"WrappedPeriodic", | ||
"Gibbs", | ||
"Coregion", | ||
"ScaledCov", | ||
|
@@ -502,12 +503,20 @@ def square_dist(self, X, Xs): | |
|
||
def euclidean_dist(self, X, Xs): | ||
r2 = self.square_dist(X, Xs) | ||
return self._sqrt(r2) | ||
|
||
def _sqrt(self, r2): | ||
return pt.sqrt(r2 + 1e-12) | ||
|
||
def diag(self, X): | ||
return pt.alloc(1.0, X.shape[0]) | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
r2 = self.square_dist(X, Xs) | ||
return self.full_from_distance(r2, squared=True) | ||
|
||
def full_from_distance(self, dist, squared=False): | ||
raise NotImplementedError | ||
|
||
def power_spectral_density(self, omega): | ||
|
@@ -544,8 +553,14 @@ def full(self, X, Xs=None): | |
f1 = X.dimshuffle(0, "x", 1) | ||
f2 = Xs.dimshuffle("x", 0, 1) | ||
r = np.pi * (f1 - f2) / self.period | ||
r = pt.sum(pt.square(pt.sin(r) / self.ls), 2) | ||
return pt.exp(-0.5 * r) | ||
r2 = pt.sum(pt.square(pt.sin(r) / self.ls), 2) | ||
return self.full_from_distance(r2, squared=True) | ||
|
||
def full_from_distance(self, dist, squared=False): | ||
# NOTE: This is the same as the ExpQuad as we assume the periodicity | ||
# has already been accounted for in the distance | ||
r2 = dist if squared else dist ** 2 | ||
return pt.exp(-0.5 * r2) | ||
|
||
|
||
class ExpQuad(Stationary): | ||
|
@@ -559,9 +574,9 @@ class ExpQuad(Stationary): | |
|
||
""" | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
return pt.exp(-0.5 * self.square_dist(X, Xs)) | ||
def full_from_distance(self, dist, squared=False): | ||
r2 = dist if squared else dist ** 2 | ||
return pt.exp(-0.5 * r2) | ||
|
||
def power_spectral_density(self, omega): | ||
r""" | ||
|
@@ -592,10 +607,10 @@ def __init__(self, input_dim, alpha, ls=None, ls_inv=None, active_dims=None): | |
super().__init__(input_dim, ls, ls_inv, active_dims) | ||
self.alpha = alpha | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
def full_from_distance(self, dist, squared=False): | ||
r2 = dist if squared else dist ** 2 | ||
return pt.power( | ||
(1.0 + 0.5 * self.square_dist(X, Xs) * (1.0 / self.alpha)), | ||
(1.0 + 0.5 * r2 * (1.0 / self.alpha)), | ||
-1.0 * self.alpha, | ||
) | ||
|
||
|
@@ -611,9 +626,8 @@ class Matern52(Stationary): | |
\mathrm{exp}\left[ - \frac{\sqrt{5(x - x')^2}}{\ell} \right] | ||
""" | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
r = self.euclidean_dist(X, Xs) | ||
def full_from_distance(self, dist, squared=False): | ||
r = self._sqrt(dist) if squared else dist | ||
return (1.0 + np.sqrt(5.0) * r + 5.0 / 3.0 * pt.square(r)) * pt.exp(-1.0 * np.sqrt(5.0) * r) | ||
|
||
def power_spectral_density(self, omega): | ||
|
@@ -651,9 +665,8 @@ class Matern32(Stationary): | |
\mathrm{exp}\left[ - \frac{\sqrt{3(x - x')^2}}{\ell} \right] | ||
""" | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
r = self.euclidean_dist(X, Xs) | ||
def full_from_distance(self, dist, squared=False): | ||
r = self._sqrt(dist) if squared else dist | ||
return (1.0 + np.sqrt(3.0) * r) * pt.exp(-np.sqrt(3.0) * r) | ||
|
||
def power_spectral_density(self, omega): | ||
|
@@ -690,9 +703,8 @@ class Matern12(Stationary): | |
k(x, x') = \mathrm{exp}\left[ -\frac{(x - x')^2}{\ell} \right] | ||
""" | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
r = self.euclidean_dist(X, Xs) | ||
def full_from_distance(self, dist, squared=False): | ||
r = self._sqrt(dist) if squared else dist | ||
return pt.exp(-r) | ||
|
||
|
||
|
@@ -705,9 +717,9 @@ class Exponential(Stationary): | |
k(x, x') = \mathrm{exp}\left[ -\frac{||x - x'||}{2\ell} \right] | ||
""" | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
return pt.exp(-0.5 * self.euclidean_dist(X, Xs)) | ||
def full_from_distance(self, dist, squared=False): | ||
r = self._sqrt(dist) if squared else dist | ||
return pt.exp(-0.5 * r) | ||
|
||
|
||
class Cosine(Stationary): | ||
|
@@ -718,9 +730,9 @@ class Cosine(Stationary): | |
k(x, x') = \mathrm{cos}\left( 2 \pi \frac{||x - x'||}{ \ell^2} \right) | ||
""" | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
return pt.cos(2.0 * np.pi * self.euclidean_dist(X, Xs)) | ||
def full_from_distance(self, dist, squared=False): | ||
r = self._sqrt(dist) if squared else dist | ||
return pt.cos(2.0 * np.pi * r) | ||
|
||
|
||
class Linear(Covariance): | ||
|
@@ -812,6 +824,52 @@ def full(self, X, Xs=None): | |
def diag(self, X): | ||
X, _ = self._slice(X, None) | ||
return self.cov_func(self.w(X, self.args), diag=True) | ||
|
||
|
||
class WrappedPeriodic(Covariance): | ||
r""" | ||
Wrap a stationary covariance function to make it periodic. | ||
|
||
This is done by warping the input with the function | ||
|
||
.. math:: | ||
\mathbf{u}(x) = \left( | ||
\mathrm{sin} \left( \frac{2\pi x}{T} \right), | ||
\mathrm{cos} \left( \frac{2\pi x}{T} \right) | ||
\right) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be nice to add something like, "the Also, I think it'd be nice to add a note that describes and gives the code that makes this function equivalent to Also, the function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have addressed these in latest commit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you! Super nice |
||
Parameters | ||
---------- | ||
cov_func: Stationary | ||
Base kernel or covariance function | ||
period: Period | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_dim: int, | ||
cov_func: Stationary, | ||
period, | ||
active_dims: Optional[Sequence[int]] = None, | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense. My only concern would be it is then the only |
||
super().__init__(input_dim, active_dims) | ||
if not isinstance(cov_func, Stationary): | ||
raise TypeError("Must inherit from the Stationary class") | ||
self.cov_func = cov_func | ||
self.period = period | ||
|
||
def full(self, X, Xs=None): | ||
X, Xs = self._slice(X, Xs) | ||
if Xs is None: | ||
Xs = X | ||
f1 = pt.expand_dims(X, axis=(0,)) | ||
f2 = pt.expand_dims(Xs, axis=(1,)) | ||
r = np.pi * (f1 - f2) / self.period | ||
r2 = 4 * pt.sum(pt.square(pt.sin(r) / self.cov_func.ls), 2) | ||
return self.cov_func.full_from_distance(r2, squared=True) | ||
|
||
def diag(self, X): | ||
return pt.alloc(1.0, X.shape[0]) | ||
|
||
|
||
class Gibbs(Covariance): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you had GeneralizedPeriodic originally as the name, why the switch? I think GeneralizedPeriodic makes it a bit clearer what it's doing.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I felt it captured better what it was doing i.e. you use it to wrap up an existing kernel to make it periodic. I think a good name might be a verb (like
Add
orProd
) since it acts on an existing kernel...but I don't know what that verb would be :)Periodify
... But I don't mind moving back toGeneralizedPeriodic
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. I guess
Wrapped
is more describes what the code does, andGeneralized
describes what the kernel is. Either way makes sense.