Skip to content

Commit 2181649

Browse files
devmotionwilltebbuttgithub-actions[bot]st--
authored
Add with_lengthscale (alternative/extension of #335) (#336)
Co-authored-by: willtebbutt <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: ST John <[email protected]>
1 parent 58cb069 commit 2181649

File tree

7 files changed

+76
-1
lines changed

7 files changed

+76
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.6"
3+
version = "0.10.7"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/transform.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,9 @@ SelectTransform
3636
ChainTransform
3737
PeriodicTransform
3838
```
39+
40+
## Convenience functions
41+
42+
```@docs
43+
with_lengthscale
44+
```

docs/src/userguide.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ For example, a squared exponential kernel is created by
1515
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
1616
k = compose(SqExponentialKernel(), ScaleTransform(2.0))
1717
```
18+
Alternatively, you can use the convenience function [`with_lengthscale`](@ref):
19+
```julia
20+
k = with_lengthscale(SqExponentialKernel(), 0.5)
21+
```
22+
[`with_lengthscale`](@ref) also works with vector-valued lengthscales for ARD.
1823
Check the [Input Transforms](@ref input_transforms) page for more details.
1924

2025
!!! tip "How do I set the kernel variance?"

src/KernelFunctions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export Transform,
2727
IdentityTransform,
2828
FunctionTransform,
2929
PeriodicTransform
30+
export with_lengthscale
3031

3132
export NystromFact, nystrom
3233

@@ -75,6 +76,7 @@ include(joinpath("transform", "selecttransform.jl"))
7576
include(joinpath("transform", "chaintransform.jl"))
7677
include(joinpath("transform", "periodic_transform.jl"))
7778
include(joinpath("kernels", "transformedkernel.jl"))
79+
include(joinpath("transform", "with_lengthscale.jl"))
7880

7981
include(joinpath("basekernels", "constant.jl"))
8082
include(joinpath("basekernels", "cosine.jl"))

src/transform/with_lengthscale.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
with_lengthscale(kernel::Kernel, lengthscale::Real)
3+
4+
Construct a transformed kernel with `lengthscale`.
5+
6+
# Examples
7+
8+
```jldoctest
9+
julia> kernel = with_lengthscale(SqExponentialKernel(), 2.5);
10+
11+
julia> x = rand(2);
12+
13+
julia> y = rand(2);
14+
15+
julia> kernel(x, y) ≈ (SqExponentialKernel() ∘ ScaleTransform(0.4))(x, y)
16+
true
17+
```
18+
"""
19+
function with_lengthscale(kernel::Kernel, lengthscale::Real)
20+
return kernel ScaleTransform(inv(lengthscale))
21+
end
22+
23+
"""
24+
with_lengthscale(kernel::Kernel, lengthscales::AbstractVector{<:Real})
25+
26+
Construct a transformed "ARD" kernel with different `lengthscales` for each dimension.
27+
28+
# Examples
29+
30+
```jldoctest
31+
julia> kernel = with_lengthscale(SqExponentialKernel(), [0.5, 2.5]);
32+
33+
julia> x = rand(2);
34+
35+
julia> y = rand(2);
36+
37+
julia> kernel(x, y) ≈ (SqExponentialKernel() ∘ ARDTransform([2, 0.4]))(x, y)
38+
true
39+
```
40+
"""
41+
function with_lengthscale(kernel::Kernel, lengthscales::AbstractVector{<:Real})
42+
return kernel ARDTransform(map(inv, lengthscales))
43+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ include("test_utils.jl")
7676
print(" ")
7777
include(joinpath("transform", "periodic_transform.jl"))
7878
print(" ")
79+
include(joinpath("transform", "with_lengthscale.jl"))
80+
print(" ")
7981
end
8082
@info "Ran tests on Transform"
8183
end

test/transform/with_lengthscale.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@testset "with_lengthscale" begin
2+
@testset "ScaleTransform" begin
3+
l = exp(rand())
4+
kernel = @inferred(with_lengthscale(SqExponentialKernel(), l))
5+
6+
@test kernel isa TransformedKernel{<:SqExponentialKernel,<:ScaleTransform}
7+
@test kernel.transform.s[1] inv(l)
8+
end
9+
10+
@testset "ARDTransform" begin
11+
l = map(exp, rand(5))
12+
kernel = @inferred(with_lengthscale(SqExponentialKernel(), l))
13+
14+
@test kernel isa TransformedKernel{<:SqExponentialKernel,<:ARDTransform}
15+
@test kernel.transform.v map(inv, l)
16+
end
17+
end

0 commit comments

Comments
 (0)