Skip to content

Commit 9b4ca7c

Browse files
authored
Merge branch 'master' into add-gibbskernel-to-prior-example
2 parents 21372f0 + 7ce1c39 commit 9b4ca7c

File tree

16 files changed

+215
-57
lines changed

16 files changed

+215
-57
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.21"
3+
version = "0.10.26"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/api.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ To find out more about the background, read this [review of kernels for vector-v
8484

8585
KernelFunctions also provides miscellaneous utility functions.
8686
```@docs
87-
kernelpdmat
8887
nystrom
8988
NystromFact
9089
```
@@ -96,4 +95,11 @@ To keep the dependencies of KernelFunctions lean, some functionality is only ava
9695
[*https://github.com/MichielStock/Kronecker.jl*](https://github.com/MichielStock/Kronecker.jl)
9796
```@docs
9897
kronecker_kernelmatrix
98+
kernelkronmat
9999
```
100+
101+
### PDMats.jl
102+
[*https://github.com/JuliaStats/PDMats.jl*](https://github.com/JuliaStats/PDMats.jl)
103+
```@docs
104+
kernelpdmat
105+
```

src/KernelFunctions.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
5858
using LogExpFunctions: softplus
5959
using StatsBase
6060
using TensorCore
61-
using ZygoteRules: ZygoteRules
61+
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield
62+
63+
# Hack to work around Zygote type inference problems.
64+
const Distances_pairwise = Distances.pairwise
6265

6366
abstract type Kernel end
6467
abstract type SimpleKernel <: Kernel end

src/basekernels/fbm.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ end
4343
_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h) / 2
4444

4545
_mod(x::AbstractVector{<:Real}) = abs2.(x)
46+
_mod(x::AbstractVector{<:AbstractVector{<:Real}}) = sum.(abs2, x)
47+
# two lines above could be combined into the second (dispatching on general AbstractVectors), but this (somewhat) more performant
4648
_mod(x::ColVecs) = vec(sum(abs2, x.X; dims=1))
4749
_mod(x::RowVecs) = vec(sum(abs2, x.X; dims=2))
4850

src/basekernels/rational.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ Rational kernel with shape parameter `α` and given `metric`.
55
66
# Definition
77
8-
For inputs ``x, x'``, the rational kernel with shape parameter
9-
``\\alpha > 0`` is defined as
8+
For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the rational kernel with shape parameter ``\\alpha > 0`` is defined as
109
```math
11-
k(x, x'; \\alpha) = \\bigg(1 + \\frac{\\|x - x'\\|}{\\alpha}\\bigg)^{-\\alpha}.
10+
k(x, x'; \\alpha) = \\bigg(1 + \\frac{d(x, x')}{\\alpha}\\bigg)^{-\\alpha}.
1211
```
12+
By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
1313
1414
The [`ExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\infty``.
1515

src/distances/delta.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
# Delta is not following the PreMetric rules since d(x, x) == 1
22
struct Delta <: Distances.UnionPreMetric end
33

4-
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector)
4+
(dist::Delta)(a::Number, b::Number) = a == b
5+
Base.@propagate_inbounds function (dist::Delta)(
6+
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
7+
)
58
@boundscheck if length(a) != length(b)
69
throw(
710
DimensionMismatch(
8-
"first array has length $(length(a)) which does not match the length of the " *
9-
"second, $(length(b)).",
11+
"first array has length $(length(a)) which does not match the length of the second, $(length(b)).",
1012
),
1113
)
1214
end
1315
return a == b
1416
end
1517

1618
Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool
17-
18-
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
19-
@inline (dist::Delta)(a::Number, b::Number) = a == b

src/distances/pairwise.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ end
77
pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
88

99
function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector)
10-
return broadcast!(d, out, X, Y')
10+
return broadcast!(d, out, X, permutedims(Y))
1111
end
1212

1313
pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)
1414

1515
function pairwise(d::PreMetric, x::AbstractVector{<:Real})
16-
return Distances.pairwise(d, reshape(x, :, 1); dims=1)
16+
return Distances_pairwise(d, reshape(x, :, 1); dims=1)
1717
end
1818

1919
function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
20-
return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
20+
return Distances_pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
2121
end
2222

2323
function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})

src/kernels/kerneltensorproduct.jl

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,17 @@ function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVec
7171
validate_inplace_dims(K, x)
7272
validate_domain(k, x)
7373

74-
kernels_and_inputs = zip(k.kernels, slices(x))
75-
kernelmatrix!(K, first(kernels_and_inputs)...)
76-
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
77-
K .*= kernelmatrix(k, xi)
74+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
75+
first_x, tail_x = Iterators.peel(slices(x))
76+
77+
# handle first kernel and input
78+
kernelmatrix!(K, first_kernels, first_x)
79+
80+
# handle remaining kernels and inputs
81+
Ktmp = similar(K)
82+
for (ki, xi) in zip(tail_kernels, tail_x)
83+
kernelmatrix!(Ktmp, ki, xi)
84+
hadamard!(K, K, Ktmp)
7885
end
7986

8087
return K
@@ -86,10 +93,18 @@ function kernelmatrix!(
8693
validate_inplace_dims(K, x, y)
8794
validate_domain(k, x)
8895

89-
kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
90-
kernelmatrix!(K, first(kernels_and_inputs)...)
91-
for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
92-
K .*= kernelmatrix(k, xi, yi)
96+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
97+
first_x, tail_x = Iterators.peel(slices(x))
98+
first_y, tail_y = Iterators.peel(slices(y))
99+
100+
# handle first kernel and inputs
101+
kernelmatrix!(K, first_kernels, first_x, first_y)
102+
103+
# handle remaining kernels and inputs
104+
Ktmp = similar(K)
105+
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
106+
kernelmatrix!(Ktmp, ki, xi, yi)
107+
hadamard!(K, K, Ktmp)
93108
end
94109

95110
return K
@@ -99,10 +114,40 @@ function kernelmatrix_diag!(K::AbstractVector, k::KernelTensorProduct, x::Abstra
99114
validate_inplace_dims(K, x)
100115
validate_domain(k, x)
101116

102-
kernels_and_inputs = zip(k.kernels, slices(x))
103-
kernelmatrix_diag!(K, first(kernels_and_inputs)...)
104-
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
105-
K .*= kernelmatrix_diag(k, xi)
117+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
118+
first_x, tail_x = Iterators.peel(slices(x))
119+
120+
# handle first kernel and input
121+
kernelmatrix_diag!(K, first_kernels, first_x)
122+
123+
# handle remaining kernels and inputs
124+
Ktmp = similar(K)
125+
for (ki, xi) in zip(tail_kernels, tail_x)
126+
kernelmatrix_diag!(Ktmp, ki, xi)
127+
hadamard!(K, K, Ktmp)
128+
end
129+
130+
return K
131+
end
132+
133+
function kernelmatrix_diag!(
134+
K::AbstractVector, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
135+
)
136+
validate_inplace_dims(K, x, y)
137+
validate_domain(k, x)
138+
139+
first_kernels, tail_kernels = Iterators.peel(k.kernels)
140+
first_x, tail_x = Iterators.peel(slices(x))
141+
first_y, tail_y = Iterators.peel(slices(y))
142+
143+
# handle first kernel and inputs
144+
kernelmatrix_diag!(K, first_kernels, first_x, first_y)
145+
146+
# handle remaining kernels and inputs
147+
Ktmp = similar(K)
148+
for (ki, xi, yi) in zip(tail_kernels, tail_x, tail_y)
149+
kernelmatrix_diag!(Ktmp, ki, xi, yi)
150+
hadamard!(K, K, Ktmp)
106151
end
107152

108153
return K

src/kernels/scaledkernel.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ function kernelmatrix_diag!(K::AbstractVector, κ::ScaledKernel, x::AbstractVect
6161
return K
6262
end
6363

64+
function kernelmatrix_diag!(
65+
K::AbstractVector, κ::ScaledKernel, x::AbstractVector, y::AbstractVector
66+
)
67+
kernelmatrix_diag!(K, κ.kernel, x, y)
68+
K .*= κ.σ²
69+
return K
70+
end
71+
6472
Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)
6573

6674
Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)

src/matrix/kernelkroneckermat.jl

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,58 @@ using .Kronecker: Kronecker
66
export kernelkronmat
77
export kronecker_kernelmatrix
88

9-
function kernelkronmat::Kernel, X::AbstractVector, dims::Int)
10-
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see [`iskroncompatible`](@ref))"
11-
k = kernelmatrix(κ, X)
12-
return Kronecker.kronecker(k, dims)
9+
@doc raw"""
10+
kernelkronmat(κ::Kernel, X::AbstractVector{<:Real}, dims::Int) -> KroneckerPower
11+
12+
Return a `KroneckerPower` matrix on the `D`-dimensional input grid constructed by ``\otimes_{i=1}^D X``,
13+
where `D` is given by `dims`.
14+
15+
!!! warning
16+
17+
Require `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
18+
"""
19+
function kernelkronmat::Kernel, X::AbstractVector{<:Real}, dims::Int)
20+
checkkroncompatible(κ)
21+
K = kernelmatrix(κ, X)
22+
return Kronecker.kronecker(K, dims)
1323
end
1424

15-
function kernelkronmat(
16-
κ::Kernel, X::AbstractVector{<:AbstractVector}; obsdim::Int=defaultobs
17-
)
18-
@assert iskroncompatible(κ) "The chosen kernel is not compatible for Kronecker matrices"
25+
@doc raw"""
26+
27+
kernelkronmat(κ::Kernel, X::AbstractVector{<:AbstractVector}) -> KroneckerProduct
28+
29+
Returns a `KroneckerProduct` matrix on the grid built with the collection of vectors ``\{X_i\}_{i=1}^D``: ``\otimes_{i=1}^D X_i``.
30+
31+
!!! warning
32+
33+
Requires `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
34+
"""
35+
function kernelkronmat::Kernel, X::AbstractVector{<:AbstractVector})
36+
checkkroncompatible(κ)
1937
Ks = kernelmatrix.(κ, X)
20-
return K = reduce(Kronecker.:, Ks)
38+
return reduce(Kronecker.:, Ks)
2139
end
2240

23-
"""
24-
To be compatible with kroenecker constructions the kernel must satisfy
25-
the property : for x,x' ∈ ℜᴰ
26-
k(x,x') = ∏ᵢᴰ k(xᵢ,x'ᵢ)
41+
@doc raw"""
42+
iskroncompatible(k::Kernel)
43+
44+
Determine whether kernel `k` is compatible with Kronecker constructions such as [`kernelkronmat`](@ref)
45+
46+
The function returns `false` by default. If `k` is compatible it must satisfy for all ``x, x' \in \mathbb{R}^D`:
47+
```math
48+
k(x, x') = \prod_{i=1}^D k(x_i, x'_i).
49+
```
2750
"""
2851
@inline iskroncompatible::Kernel) = false # Default return for kernels
2952

53+
function checkkroncompatible::Kernel)
54+
return iskroncompatible(κ) || throw(
55+
ArgumentError(
56+
"The chosen kernel is not compatible for Kronecker matrices (see [`iskroncompatible`](@ref))",
57+
),
58+
)
59+
end
60+
3061
function _kernelmatrix_kroneckerjl_helper(
3162
::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
3263
)

src/test_utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ function test_interface(
8787

8888
tmp_diag = Vector{Float64}(undef, length(x0))
8989
@test kernelmatrix_diag!(tmp_diag, k, x0) kernelmatrix_diag(k, x0)
90+
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) kernelmatrix_diag(k, x0, x1)
9091
end
9192

9293
function test_interface(
@@ -133,6 +134,18 @@ function test_interface(
133134
)
134135
end
135136

137+
function test_interface(
138+
rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs...
139+
) where {T<:Real}
140+
return test_interface(
141+
k,
142+
[randn(rng, T, dim_in) for _ in 1:1001],
143+
[randn(rng, T, dim_in) for _ in 1:1001],
144+
[randn(rng, T, dim_in) for _ in 1:1000];
145+
kwargs...,
146+
)
147+
end
148+
136149
function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
137150
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
138151
end
@@ -147,6 +160,9 @@ function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...)
147160
@testset "RowVecs{$T}" begin
148161
test_interface(rng, k, RowVecs{T}; kwargs...)
149162
end
163+
@testset "Vector{Vector{T}}" begin
164+
test_interface(rng, k, Vector{Vector{T}}; kwargs...)
165+
end
150166
end
151167

152168
function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)

src/utils.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X))
8080

8181
dim(x::ColVecs) = size(x.X, 1)
8282

83-
pairwise(d::PreMetric, x::ColVecs) = Distances.pairwise(d, x.X; dims=2)
84-
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances.pairwise(d, x.X, y.X; dims=2)
83+
pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
84+
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
8585
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
86-
return Distances.pairwise(d, reduce(hcat, x), y.X; dims=2)
86+
return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2)
8787
end
8888
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
89-
return Distances.pairwise(d, x.X, reduce(hcat, y); dims=2)
89+
return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2)
9090
end
9191
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
9292
return Distances.pairwise!(out, d, x.X; dims=2)
@@ -150,13 +150,13 @@ Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X))
150150

151151
dim(x::RowVecs) = size(x.X, 2)
152152

153-
pairwise(d::PreMetric, x::RowVecs) = Distances.pairwise(d, x.X; dims=1)
154-
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances.pairwise(d, x.X, y.X; dims=1)
153+
pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1)
154+
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1)
155155
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
156-
return Distances.pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
156+
return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
157157
end
158158
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
159-
return Distances.pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
159+
return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
160160
end
161161
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
162162
return Distances.pairwise!(out, d, x.X; dims=1)
@@ -197,17 +197,27 @@ function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::Abstract
197197
end
198198
end
199199

200-
function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector)
201-
return validate_inplace_dims(K, x, x)
202-
end
203-
204-
function validate_inplace_dims(K::AbstractVector, x::AbstractVector)
205-
if length(K) != length(x)
200+
function validate_inplace_dims(K::AbstractVector, x::AbstractVector, y::AbstractVector)
201+
validate_inputs(x, y)
202+
n = length(x)
203+
if length(y) != n
204+
throw(
205+
DimensionMismatch(
206+
"Length of input x ($n) not consistent with length of input y " *
207+
"($(length(y))",
208+
),
209+
)
210+
end
211+
if length(K) != n
206212
throw(
207213
DimensionMismatch(
208-
"Length of target vector K ($(length(K))) not consistent with length of input" *
209-
"vector x ($(length(x))",
214+
"Length of target vector K ($(length(K))) not consistent with length of " *
215+
"inputs ($n)",
210216
),
211217
)
212218
end
213219
end
220+
221+
function validate_inplace_dims(K::AbstractVecOrMat, x::AbstractVector)
222+
return validate_inplace_dims(K, x, x)
223+
end

0 commit comments

Comments
 (0)