Skip to content

Commit 24bbeb3

Browse files
Merge branch 'master' into tgw/heterotopic
2 parents b5bb9e7 + ab30149 commit 24bbeb3

File tree

15 files changed

+160
-62
lines changed

15 files changed

+160
-62
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,8 @@ jobs:
5858
GROUP: ${{ matrix.group }}
5959
- uses: julia-actions/julia-processcoverage@v1
6060
if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
61-
- name: Coveralls parallel
61+
- name: Send coverage to CodeCov
6262
if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
63-
uses: coverallsapp/github-action@master
63+
uses: codecov/codecov-action@v2
6464
with:
65-
github-token: ${{ secrets.GITHUB_TOKEN }}
66-
path-to-lcov: ./lcov.info
67-
flag-name: run-${{ matrix.group }}
68-
parallel: true
69-
finish:
70-
needs: test
71-
runs-on: ubuntu-latest
72-
steps:
73-
- name: Send coverage
74-
uses: coverallsapp/github-action@master
75-
with:
76-
github-token: ${{ secrets.github_token }}
77-
parallel-finished: true
65+
file: lcov.info

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.24"
3+
version = "0.10.27"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# KernelFunctions.jl
22

33
![CI](https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/workflows/CI/badge.svg?branch=master)
4-
[![Coverage Status](https://coveralls.io/repos/github/JuliaGaussianProcesses/KernelFunctions.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaGaussianProcesses/KernelFunctions.jl?branch=master)
4+
[![codecov](https://codecov.io/gh/JuliaGaussianProcesses/KernelFunctions.jl/branch/master/graph/badge.svg?token=rmDh3gb7hN)](https://codecov.io/gh/JuliaGaussianProcesses/KernelFunctions.jl)
55
[![Documentation (stable)](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable)
66
[![Documentation (latest)](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliagaussianprocesses.github.io/KernelFunctions.jl/dev)
77
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

docs/src/api.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ To find out more about the background, read this [review of kernels for vector-v
8484

8585
KernelFunctions also provides miscellaneous utility functions.
8686
```@docs
87-
kernelpdmat
8887
nystrom
8988
NystromFact
9089
```
@@ -96,4 +95,11 @@ To keep the dependencies of KernelFunctions lean, some functionality is only ava
9695
[*https://github.com/MichielStock/Kronecker.jl*](https://github.com/MichielStock/Kronecker.jl)
9796
```@docs
9897
kronecker_kernelmatrix
98+
kernelkronmat
9999
```
100+
101+
### PDMats.jl
102+
[*https://github.com/JuliaStats/PDMats.jl*](https://github.com/JuliaStats/PDMats.jl)
103+
```@docs
104+
kernelpdmat
105+
```

examples/gaussian-process-priors/script.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ kernels = [
120120
LinearKernel(),
121121
compose(PeriodicKernel(), ScaleTransform(0.2)),
122122
NeuralNetworkKernel(),
123+
GibbsKernel(; lengthscale=x -> sum(exp sin, x)),
123124
]
124125
plot(
125126
[visualize(k) for k in kernels]...;

src/basekernels/rational.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ Rational kernel with shape parameter `α` and given `metric`.
55
66
# Definition
77
8-
For inputs ``x, x'``, the rational kernel with shape parameter
9-
``\\alpha > 0`` is defined as
8+
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the rational kernel with shape parameter ``\\alpha > 0`` is defined as
109
```math
11-
k(x, x'; \\alpha) = \\bigg(1 + \\frac{\\|x - x'\\|}{\\alpha}\\bigg)^{-\\alpha}.
10+
k(x, x'; \\alpha) = \\bigg(1 + \\frac{d(x, x')}{\\alpha}\\bigg)^{-\\alpha}.
1211
```
12+
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
1313
1414
The [`ExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\infty``.
1515

src/distances/delta.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
# Delta is not following the PreMetric rules since d(x, x) == 1
22
struct Delta <: Distances.UnionPreMetric end
33

4-
@inline Distances.eval_op(::Delta, a::Real, b::Real) = a == b
5-
@inline Distances.eval_reduce(::Delta, a, b) = a && b
6-
@inline Distances.eval_start(::Delta, a, b) = true
7-
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
8-
@inline (dist::Delta)(a::Number, b::Number) = a == b
4+
(dist::Delta)(a::Number, b::Number) = a == b
5+
Base.@propagate_inbounds function (dist::Delta)(
6+
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
7+
)
8+
@boundscheck if length(a) != length(b)
9+
throw(
10+
DimensionMismatch(
11+
"first array has length $(length(a)) which does not match the length of the second, $(length(b)).",
12+
),
13+
)
14+
end
15+
return a == b
16+
end
917

1018
Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool

src/distances/pairwise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77
pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
88

99
function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector)
10-
return broadcast!(d, out, X, Y')
10+
return broadcast!(d, out, X, permutedims(Y))
1111
end
1212

1313
pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)

src/kernels/kerneltensorproduct.jl

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,17 @@ function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVec
7171
validate_inplace_dims(K, x)
7272
validate_domain(k, x)
7373

74-
kernels_and_inputs = zip(k.kernels, slices(x))
75-
kernelmatrix!(K, first(kernels_and_inputs)...)
76-
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
77-
K .*= kernelmatrix(k, xi)
74+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
75+
first_x, tail_x = Iterators.peel(slices(x))
76+
77+
# handle first kernel and input
78+
kernelmatrix!(K, first_kernels, first_x)
79+
80+
# handle remaining kernels and inputs
81+
Ktmp = similar(K)
82+
for (ki, xi) in zip(tail_kernels, tail_x)
83+
kernelmatrix!(Ktmp, ki, xi)
84+
hadamard!(K, K, Ktmp)
7885
end
7986

8087
return K
@@ -86,10 +93,18 @@ function kernelmatrix!(
8693
validate_inplace_dims(K, x, y)
8794
validate_domain(k, x)
8895

89-
kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
90-
kernelmatrix!(K, first(kernels_and_inputs)...)
91-
for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
92-
K .*= kernelmatrix(k, xi, yi)
96+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
97+
first_x, tail_x = Iterators.peel(slices(x))
98+
first_y, tail_y = Iterators.peel(slices(y))
99+
100+
# handle first kernel and inputs
101+
kernelmatrix!(K, first_kernels, first_x, first_y)
102+
103+
# handle remaining kernels and inputs
104+
Ktmp = similar(K)
105+
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
106+
kernelmatrix!(Ktmp, ki, xi, yi)
107+
hadamard!(K, K, Ktmp)
93108
end
94109

95110
return K
@@ -99,10 +114,40 @@ function kernelmatrix_diag!(K::AbstractVector, k::KernelTensorProduct, x::Abstra
99114
validate_inplace_dims(K, x)
100115
validate_domain(k, x)
101116

102-
kernels_and_inputs = zip(k.kernels, slices(x))
103-
kernelmatrix_diag!(K, first(kernels_and_inputs)...)
104-
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
105-
K .*= kernelmatrix_diag(k, xi)
117+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
118+
first_x, tail_x = Iterators.peel(slices(x))
119+
120+
# handle first kernel and input
121+
kernelmatrix_diag!(K, first_kernels, first_x)
122+
123+
# handle remaining kernels and inputs
124+
Ktmp = similar(K)
125+
for (ki, xi) in zip(tail_kernels, tail_x)
126+
kernelmatrix_diag!(Ktmp, ki, xi)
127+
hadamard!(K, K, Ktmp)
128+
end
129+
130+
return K
131+
end
132+
133+
function kernelmatrix_diag!(
134+
K::AbstractVector, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
135+
)
136+
validate_inplace_dims(K, x, y)
137+
validate_domain(k, x)
138+
139+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
140+
first_x, tail_x = Iterators.peel(slices(x))
141+
first_y, tail_y = Iterators.peel(slices(y))
142+
143+
# handle first kernel and inputs
144+
kernelmatrix_diag!(K, first_kernels, first_x, first_y)
145+
146+
# handle remaining kernels and inputs
147+
Ktmp = similar(K)
148+
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
149+
kernelmatrix_diag!(Ktmp, ki, xi, yi)
150+
hadamard!(K, K, Ktmp)
106151
end
107152

108153
return K

src/kernels/scaledkernel.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ function kernelmatrix_diag!(K::AbstractVector, κ::ScaledKernel, x::AbstractVect
6161
return K
6262
end
6363

64+
function kernelmatrix_diag!(
65+
K::AbstractVector, κ::ScaledKernel, x::AbstractVector, y::AbstractVector
66+
)
67+
kernelmatrix_diag!(K, κ.kernel, x, y)
68+
K .*= κ.σ²
69+
return K
70+
end
71+
6472
Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)
6573

6674
Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)

src/matrix/kernelkroneckermat.jl

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,58 @@ using .Kronecker: Kronecker
66
export kernelkronmat
77
export kronecker_kernelmatrix
88

9-
function kernelkronmat::Kernel, X::AbstractVector, dims::Int)
10-
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see [`iskroncompatible`](@ref))"
11-
k = kernelmatrix(κ, X)
12-
return Kronecker.kronecker(k, dims)
9+
@doc raw"""
10+
kernelkronmat(κ::Kernel, X::AbstractVector{<:Real}, dims::Int) -> KroneckerPower
11+
12+
Return a `KroneckerPower` matrix on the `D`-dimensional input grid constructed by ``\otimes_{i=1}^D X``,
13+
where `D` is given by `dims`.
14+
15+
!!! warning
16+
17+
Require `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
18+
"""
19+
function kernelkronmat::Kernel, X::AbstractVector{<:Real}, dims::Int)
20+
checkkroncompatible(κ)
21+
K = kernelmatrix(κ, X)
22+
return Kronecker.kronecker(K, dims)
1323
end
1424

15-
function kernelkronmat(
16-
κ::Kernel, X::AbstractVector{<:AbstractVector}; obsdim::Int=defaultobs
17-
)
18-
@assert iskroncompatible(κ) "The chosen kernel is not compatible for Kronecker matrices"
25+
@doc raw"""
26+
27+
kernelkronmat(κ::Kernel, X::AbstractVector{<:AbstractVector}) -> KroneckerProduct
28+
29+
Returns a `KroneckerProduct` matrix on the grid built with the collection of vectors ``\{X_i\}_{i=1}^D``: ``\otimes_{i=1}^D X_i``.
30+
31+
!!! warning
32+
33+
Requires `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
34+
"""
35+
function kernelkronmat::Kernel, X::AbstractVector{<:AbstractVector})
36+
checkkroncompatible(κ)
1937
Ks = kernelmatrix.(κ, X)
20-
return K = reduce(Kronecker.:, Ks)
38+
return reduce(Kronecker.:, Ks)
2139
end
2240

23-
"""
24-
To be compatible with kroenecker constructions the kernel must satisfy
25-
the property : for x,x' ∈ ℜᴰ
26-
k(x,x') = ∏ᵢᴰ k(xᵢ,x'ᵢ)
41+
@doc raw"""
42+
iskroncompatible(k::Kernel)
43+
44+
Determine whether kernel `k` is compatible with Kronecker constructions such as [`kernelkronmat`](@ref)
45+
46+
The function returns `false` by default. If `k` is compatible it must satisfy for all ``x, x' \in \mathbb{R}^D`:
47+
```math
48+
k(x, x') = \prod_{i=1}^D k(x_i, x'_i).
49+
```
2750
"""
2851
@inline iskroncompatible::Kernel) = false # Default return for kernels
2952

53+
function checkkroncompatible::Kernel)
54+
return iskroncompatible(κ) || throw(
55+
ArgumentError(
56+
"The chosen kernel is not compatible for Kronecker matrices (see [`iskroncompatible`](@ref))",
57+
),
58+
)
59+
end
60+
3061
function _kernelmatrix_kroneckerjl_helper(
3162
::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
3263
)

src/test_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ function test_interface(
8787

8888
tmp_diag = Vector{Float64}(undef, length(x0))
8989
@test kernelmatrix_diag!(tmp_diag, k, x0) kernelmatrix_diag(k, x0)
90+
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) kernelmatrix_diag(k, x0, x1)
9091
end
9192

9293
function test_interface(

src/utils.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,27 @@ function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::Abstract
197197
end
198198
end
199199

200-
function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector)
201-
return validate_inplace_dims(K, x, x)
202-
end
203-
204-
function validate_inplace_dims(K::AbstractVector, x::AbstractVector)
205-
if length(K) != length(x)
200+
function validate_inplace_dims(K::AbstractVector, x::AbstractVector, y::AbstractVector)
201+
validate_inputs(x, y)
202+
n = length(x)
203+
if length(y) != n
204+
throw(
205+
DimensionMismatch(
206+
"Length of input x ($n) not consistent with length of input y " *
207+
"($(length(y))",
208+
),
209+
)
210+
end
211+
if length(K) != n
206212
throw(
207213
DimensionMismatch(
208-
"Length of target vector K ($(length(K))) not consistent with length of input" *
209-
"vector x ($(length(x))",
214+
"Length of target vector K ($(length(K))) not consistent with length of " *
215+
"inputs ($n)",
210216
),
211217
)
212218
end
213219
end
220+
221+
function validate_inplace_dims(K::AbstractVecOrMat, x::AbstractVector)
222+
return validate_inplace_dims(K, x, x)
223+
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1717

1818
[compat]
1919
AxisArrays = "0.4.3"
20-
Distances = "0.9, 0.10"
20+
Distances = "= 0.10.0, = 0.10.1, = 0.10.2, = 0.10.3, = 0.10.4"
2121
Documenter = "0.25, 0.26, 0.27"
2222
FiniteDifferences = "0.10.8, 0.11, 0.12"
2323
ForwardDiff = "0.10"

test/matrix/kernelkroneckermat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
@test all(collect(kernelkronmat(k, collect(x), 2)) .≈ kernelmatrix(k, X; obsdim=1))
88
@test all(collect(kernelkronmat(k, [x, x])) .≈ kernelmatrix(k, X; obsdim=1))
9-
@test_throws AssertionError kernelkronmat(LinearKernel(), collect(x), 2)
9+
@test_throws ArgumentError kernelkronmat(LinearKernel(), collect(x), 2)
1010

1111
@testset "lazy kernelmatrix" begin
1212
rng = MersenneTwister(123)

0 commit comments

Comments
 (0)