-
Notifications
You must be signed in to change notification settings - Fork 36
Extension of #269: Use \circ
and compose
and deprecate transform
#276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5c1ae08
46343d7
8280d27
777f7b2
e653fe4
80b57e3
d3a3d6b
3969134
a004f60
947a35b
6026b79
dde4c1d
7fa7e89
ee3914b
37d274d
2710c17
667e177
0440b94
1f9ba02
dc2989e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
@deprecate transform(k::Kernel, t::Transform) k ∘ t | ||
@deprecate transform(k::TransformedKernel, t::Transform) k.kernel ∘ t ∘ k.transform | ||
@deprecate transform(k::Kernel, ρ::Real) k ∘ ScaleTransform(ρ) | ||
@deprecate transform(k::Kernel, ρ::AbstractVector) k ∘ ARDTransform(ρ) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,17 +3,11 @@ | |
|
||
Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`. | ||
|
||
It is preferred to create kernels with input transformations with [`transform`](@ref) | ||
instead of `TransformedKernel` directly since [`transform`](@ref) allows optimized | ||
implementations for specific kernels and transformations. | ||
The preferred way to create kernels with input transformations is to use the composition | ||
operator [`∘`](@ref) or its alias `compose` instead of `TransformedKernel` directly since | ||
this allows optimized implementations for specific kernels and transformations. | ||
|
||
# Definition | ||
|
||
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by | ||
input transformation ``t`` is defined as | ||
```math | ||
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big). | ||
``` | ||
See also: [`∘`](@ref) | ||
""" | ||
struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel | ||
kernel::Tk | ||
|
@@ -42,30 +36,37 @@ end | |
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y)) | ||
|
||
""" | ||
transform(k::Kernel, t::Transform) | ||
kernel ∘ transform | ||
∘(kernel, transform) | ||
compose(kernel, transform) | ||
|
||
Create a [`TransformedKernel`](@ref) for kernel `k` and transform `t`. | ||
""" | ||
transform(k::Kernel, t::Transform) = TransformedKernel(k, t) | ||
function transform(k::TransformedKernel, t::Transform) | ||
return TransformedKernel(k.kernel, t ∘ k.transform) | ||
end | ||
Compose a `kernel` with a transformation `transform` of its inputs. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we repeat the mathematical definition from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like redundancy 😛 The docstring of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant to access the docs from the REPL :) I personally almost never check API pages from the docs but overuse |
||
|
||
""" | ||
transform(k::Kernel, ρ::Real) | ||
The prefix forms support chains of multiple transformations: | ||
`∘(kernel, transform1, transform2) = kernel ∘ transform1 ∘ transform2`. | ||
|
||
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscale `ρ`. | ||
""" | ||
transform(k::Kernel, ρ::Real) = transform(k, ScaleTransform(ρ)) | ||
# Definition | ||
|
||
""" | ||
transform(k::Kernel, ρ::AbstractVector) | ||
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by | ||
input transformation ``t`` is defined as | ||
```math | ||
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big). | ||
``` | ||
|
||
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscales `ρ`. | ||
""" | ||
transform(k::Kernel, ρ::AbstractVector) = transform(k, ARDTransform(ρ)) | ||
# Examples | ||
|
||
```jldoctest | ||
julia> (SqExponentialKernel() ∘ ScaleTransform(0.5))(0, 2) == exp(-0.5) | ||
true | ||
|
||
kernel(κ) = κ.kernel | ||
theogf marked this conversation as resolved.
Show resolved
Hide resolved
|
||
julia> ∘(ExponentialKernel(), ScaleTransform(2), ScaleTransform(0.5))(1, 2) == exp(-1) | ||
true | ||
``` | ||
|
||
See also: [`TransformedKernel`](@ref) | ||
""" | ||
Base.:∘(k::Kernel, t::Transform) = TransformedKernel(k, t) | ||
Base.:∘(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform ∘ t) | ||
|
||
Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0) | ||
|
||
|
@@ -87,13 +88,13 @@ function kernelmatrix_diag!( | |
end | ||
|
||
function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector) | ||
return kernelmatrix!(K, kernel(κ), _map(κ.transform, x)) | ||
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x)) | ||
end | ||
|
||
function kernelmatrix!( | ||
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector | ||
) | ||
return kernelmatrix!(K, kernel(κ), _map(κ.transform, x), _map(κ.transform, y)) | ||
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) | ||
end | ||
|
||
function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector) | ||
|
@@ -105,9 +106,9 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::Abstract | |
end | ||
|
||
function kernelmatrix(κ::TransformedKernel, x::AbstractVector) | ||
return kernelmatrix(kernel(κ), _map(κ.transform, x)) | ||
return kernelmatrix(κ.kernel, _map(κ.transform, x)) | ||
end | ||
|
||
function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) | ||
return kernelmatrix(kernel(κ), _map(κ.transform, x), _map(κ.transform, y)) | ||
return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
@testset "deprecations.jl" begin | ||
p = rand() | ||
v = rand(3) | ||
M = rand(3, 3) | ||
v1 = rand(3) | ||
v2 = rand(3) | ||
kernel = SqExponentialKernel() | ||
|
||
k1 = @test_deprecated transform(kernel, LinearTransform(M)) | ||
@test k1(v1, v2) == (kernel ∘ LinearTransform(M))(v1, v2) | ||
|
||
k2 = @test_deprecated transform(kernel ∘ ScaleTransform(p), ARDTransform(v)) | ||
@test k2(v1, v2) == (kernel ∘ ARDTransform(v) ∘ ScaleTransform(p))(v1, v2) | ||
|
||
k3 = @test_deprecated transform(kernel, p) | ||
@test k3(v1, v2) == (kernel ∘ ScaleTransform(p))(v1, v2) | ||
|
||
k4 = @test_deprecated transform(kernel, v) | ||
@test k4(v1, v2) == (kernel ∘ ARDTransform(v))(v1, v2) | ||
end |
Uh oh!
There was an error while loading. Please reload this page.