
Commit 70e3593

Transfer NeuralKernelNetwork over from Stheno (#283)
* NeuralKernelNetwork
* Bump patch
* Remove redundant line
* Improve docs slightly
* Fix formatting
* Fix formatting
* Remove Flux dep
* Remove Flux
* Apply suggestions from code review
* Add compat for LogExpFunctions
* Stop exporting NKN
* Fix formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 536bdcb commit 70e3593

7 files changed (+211 −6 lines)


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "KernelFunctions"
 uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
-version = "0.9.5"
+version = "0.9.6"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/KernelFunctions.jl

Lines changed: 7 additions & 5 deletions
@@ -45,15 +45,16 @@ using Compat
 using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS
 using ChainRulesCore: @thunk, InplaceableThunk
 using CompositionsBase
-using Requires
-using Distances, LinearAlgebra
+using Distances
+using FillArrays
 using Functors
+using LinearAlgebra
+using Requires
 using SpecialFunctions: loggamma, besselk, polygamma
-using ZygoteRules: ZygoteRules
-using StatsFuns: logtwo, twoπ
+using StatsFuns: logtwo, twoπ, softplus
 using StatsBase
 using TensorCore
-using FillArrays
+using ZygoteRules: ZygoteRules
 
 abstract type Kernel end
 abstract type SimpleKernel <: Kernel end

@@ -96,6 +97,7 @@ include(joinpath("kernels", "kernelsum.jl"))
 include(joinpath("kernels", "kernelproduct.jl"))
 include(joinpath("kernels", "kerneltensorproduct.jl"))
 include(joinpath("kernels", "overloads.jl"))
+include(joinpath("kernels", "neuralkernelnetwork.jl"))
 include(joinpath("approximations", "nystrom.jl"))
 include("generic.jl")
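One note on the import change above, since it matters for the new file: `softplus` (now pulled in from StatsFuns) is used by the NKN's LinearLayer to map unconstrained weights to positive ones, which keeps the learned combination of kernels a valid nonnegative (PSD-preserving) sum. A quick sketch of the function's behaviour, my gloss rather than part of the diff:

```julia
using StatsFuns: softplus

# softplus(x) = log(1 + exp(x)) is positive for all real x, so weights passed
# through it always yield a nonnegative combination of kernel matrices.
@assert softplus(3.0) ≈ log1p(exp(3.0))
@assert softplus(-10.0) > 0
```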

src/distances/pairwise.jl

Lines changed: 8 additions & 0 deletions
@@ -36,6 +36,14 @@ function colwise(d::PreMetric, x::AbstractVector)
     return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition
 end
 
+function colwise(d::PreMetric, x::ColVecs)
+    return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition
+end
+
+function colwise(d::PreMetric, x::RowVecs)
+    return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition
+end
+
 ## The following is a hack for DotProduct and Delta to still work
 function colwise(d::Distances.UnionPreMetric, x::ColVecs)
     return Distances.colwise(d, x.X, x.X)
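
A brief hedged sketch of the property these new methods exploit: `colwise` here is KernelFunctions' internal single-argument helper (not the exported `Distances.colwise`), and for any `PreMetric` d we have d(x, x) == 0, so the vector of self-distances is identically zero:

```julia
using Distances
using KernelFunctions

X = ColVecs(randn(3, 5))  # 5 observations, each a column in R^3

# d(x, x) == 0 for every PreMetric, so the new ColVecs/RowVecs methods can
# return a zero vector of the right element type without any computation.
@assert KernelFunctions.colwise(SqEuclidean(), X) == zeros(5)
```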

src/kernels/neuralkernelnetwork.jl

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
# Linear layer: performs a linear transformation of the input array,
#   x₁ = softplus.(W) * x₀
struct LinearLayer{T,MT<:AbstractArray{T}}
    W::MT
end
@functor LinearLayer

LinearLayer(in_dim, out_dim) = LinearLayer(randn(out_dim, in_dim))

(lin::LinearLayer)(x) = softplus.(lin.W) * x

function Base.show(io::IO, layer::LinearLayer)
    return print(io, "LinearLayer(", size(layer.W, 2), ", ", size(layer.W, 1), ")")
end

# Product function: given a 2d array of size M×N, the product layer multiplies
# every `step` neighbouring rows of the array elementwise to obtain a new array
# of size (M÷step)×N.
function product(x, step=2)
    m, n = size(x)
    m % step == 0 || error("the first dimension of the input must be a multiple of step")
    new_x = reshape(x, step, m ÷ step, n)
    return .*([new_x[i, :, :] for i in 1:step]...)
end

# Primitive layer: mainly acts as a container holding the basic kernels for the neural kernel network.
struct Primitive{T}
    kernels::T
    Primitive(ks...) = new{typeof(ks)}(ks)
end
@functor Primitive

# Flatten k kernel matrices of size M×N and concatenate the resulting 1d arrays into a k×(M*N) 2d array.
_cat_kernel_array(x) = vcat([reshape(x[i], 1, :) for i in 1:length(x)]...)

# NOTE: although we implement the `ew` & `pw` functions for Primitive, it isn't a subtype
# of Kernel; this is done because it facilitates writing NeuralKernelNetwork.
ew(p::Primitive, x) = _cat_kernel_array(map(k -> kernelmatrix_diag(k, x), p.kernels))
pw(p::Primitive, x) = _cat_kernel_array(map(k -> kernelmatrix(k, x), p.kernels))

function ew(p::Primitive, x, x′)
    return _cat_kernel_array(map(k -> kernelmatrix_diag(k, x, x′), p.kernels))
end
pw(p::Primitive, x, x′) = _cat_kernel_array(map(k -> kernelmatrix(k, x, x′), p.kernels))

function Base.show(io::IO, layer::Primitive)
    print(io, "Primitive(")
    join(io, layer.kernels, ", ")
    return print(io, ")")
end

"""
    NeuralKernelNetwork(primitives, nn)

Constructs a Neural Kernel Network (NKN) [1].

`primitives` are the base kernels, which are combined by `nn`.

```julia
k1 = 0.6 * (SEKernel() ∘ ScaleTransform(0.5))
k2 = 0.4 * (Matern32Kernel() ∘ ScaleTransform(0.1))
primitives = Primitive(k1, k2)
nkn = NeuralKernelNetwork(primitives, Chain(LinearLayer(2, 2), product))
```

[1] - Sun, Shengyang, et al. "Differentiable compositional kernel learning for Gaussian
processes." International Conference on Machine Learning. PMLR, 2018.
"""
struct NeuralKernelNetwork{PT,NNT} <: Kernel
    primitives::PT
    nn::NNT
end
@functor NeuralKernelNetwork

# Use these functions to reshape the 1d array back into a kernel matrix / diagonal.
_rebuild_kernel(x, n, m) = reshape(x, n, m)
_rebuild_diag(x) = reshape(x, :)

(κ::NeuralKernelNetwork)(x, y) = only(kernelmatrix(κ, [x], [y]))

function kernelmatrix_diag(nkn::NeuralKernelNetwork, x::AbstractVector)
    return _rebuild_diag(nkn.nn(ew(nkn.primitives, x)))
end

function kernelmatrix(nkn::NeuralKernelNetwork, x::AbstractVector)
    return _rebuild_kernel(nkn.nn(pw(nkn.primitives, x)), length(x), length(x))
end

function kernelmatrix_diag(nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector)
    return _rebuild_diag(nkn.nn(ew(nkn.primitives, x, x′)))
end

function kernelmatrix(nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector)
    return _rebuild_kernel(nkn.nn(pw(nkn.primitives, x, x′)), length(x), length(x′))
end

function kernelmatrix_diag!(K::AbstractVector, nkn::NeuralKernelNetwork, x::AbstractVector)
    K .= kernelmatrix_diag(nkn, x)
    return K
end

function kernelmatrix!(K::AbstractMatrix, nkn::NeuralKernelNetwork, x::AbstractVector)
    K .= kernelmatrix(nkn, x)
    return K
end

function kernelmatrix_diag!(
    K::AbstractVector, nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector
)
    K .= kernelmatrix_diag(nkn, x, x′)
    return K
end

function kernelmatrix!(
    K::AbstractMatrix, nkn::NeuralKernelNetwork, x::AbstractVector, x′::AbstractVector
)
    K .= kernelmatrix(nkn, x, x′)
    return K
end

function Base.show(io::IO, kernel::NeuralKernelNetwork)
    print(io, "NeuralKernelNetwork(")
    join(io, [kernel.primitives, kernel.nn], ", ")
    return print(io, ")")
end
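
For orientation, here is a hedged usage sketch of the new kernel, mirroring the docstring example and the test setup below. It assumes Flux is available for `Chain` (Flux is a test-only dependency of this package), and imports the internal names explicitly since this commit stops exporting NKN:

```julia
using KernelFunctions, Flux
using KernelFunctions: NeuralKernelNetwork, LinearLayer, Primitive, product

# `product` multiplies neighbouring rows elementwise: a 2×N array becomes 1×N.
product([1.0 2.0; 3.0 4.0])  # == [3.0 8.0]

# Two base kernels, combined by a softplus-weighted linear layer followed by
# an elementwise product over its two output rows.
k1 = 0.6 * (SEKernel() ∘ ScaleTransform(0.5))
k2 = 0.4 * (Matern32Kernel() ∘ ScaleTransform(0.1))
nkn = NeuralKernelNetwork(Primitive(k1, k2), Chain(LinearLayer(2, 2), product))

x = randn(10)
K = kernelmatrix(nkn, x)       # 10×10 kernel matrix
d = kernelmatrix_diag(nkn, x)  # length-10 diagonal, without forming K
k12 = nkn(x[1], x[2])          # single evaluation
```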

test/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

@@ -21,6 +22,7 @@ Documenter = "0.25, 0.26"
 FiniteDifferences = "0.10.8, 0.11, 0.12"
 Flux = "0.10, 0.11, 0.12"
 ForwardDiff = "0.10"
+LogExpFunctions = "0.2"
 Kronecker = "0.4"
 PDMats = "0.9, 0.10, 0.11"
 ReverseDiff = "1.2"

test/kernels/neuralkernelnetwork.jl

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
using KernelFunctions: NeuralKernelNetwork, LinearLayer, product, Primitive

@testset "neural_kernel_network" begin
    rng, N, N′, D = MersenneTwister(123456), 5, 6, 2
    x0 = collect(range(-2.0, 2.0; length=N)) .+ 1e-3 .* randn(rng, N)
    x1 = collect(range(-1.7, 2.3; length=N)) .+ 1e-3 .* randn(rng, N)
    x2 = collect(range(-1.7, 3.3; length=N′)) .+ 1e-3 .* randn(rng, N′)

    X0 = ColVecs(randn(rng, D, N))
    X1 = ColVecs(randn(rng, D, N))
    X2 = ColVecs(randn(rng, D, N′))

    # Most of the NeuralKernelNetwork tests are currently broken.
    @testset "general test" begin

        # Specify primitives.
        k1 = 0.6 * (SEKernel() ∘ ScaleTransform(0.5))
        k2 = 0.4 * (Matern32Kernel() ∘ ScaleTransform(0.1))
        primitives = Primitive(k1, k2)

        # Build NKN kernel.
        nkn = NeuralKernelNetwork(primitives, Chain(LinearLayer(2, 2), product))

        # Apply standard test suite.
        TestUtils.test_interface(nkn, Float64)
    end
    @testset "kernel composition test" begin
        rng = MersenneTwister(123456)

        # Specify primitives.
        k1 = rand(rng) * transform(SEKernel(), randn(rng))
        k2 = rand(rng) * transform(Matern32Kernel(), randn(rng))
        primitives = Primitive(k1, k2)

        @testset "LinearLayer" begin
            # Specify linear NKN and equivalent composite kernel.
            weights = rand(rng, 1, 2)
            nkn_add_kernel = NeuralKernelNetwork(primitives, LinearLayer(weights))
            sum_k =
                LogExpFunctions.softplus(weights[1]) * k1 +
                LogExpFunctions.softplus(weights[2]) * k2

            # Vector input.
            @test kernelmatrix_diag(nkn_add_kernel, x0) ≈ kernelmatrix_diag(sum_k, x0)
            @test kernelmatrix_diag(nkn_add_kernel, x0, x1) ≈
                kernelmatrix_diag(sum_k, x0, x1)

            # ColVecs input.
            @test kernelmatrix_diag(nkn_add_kernel, X0) ≈ kernelmatrix_diag(sum_k, X0)
            @test kernelmatrix_diag(nkn_add_kernel, X0, X1) ≈
                kernelmatrix_diag(sum_k, X0, X1)
        end
        @testset "product" begin
            nkn_prod_kernel = NeuralKernelNetwork(primitives, product)
            prod_k = k1 * k2

            # Vector input.
            @test kernelmatrix(nkn_prod_kernel, x0) ≈ kernelmatrix(prod_k, x0)
            @test kernelmatrix(nkn_prod_kernel, x0, x1) ≈ kernelmatrix(prod_k, x0, x1)

            # ColVecs input.
            @test kernelmatrix(nkn_prod_kernel, X0) ≈ kernelmatrix(prod_k, X0)
            @test kernelmatrix(nkn_prod_kernel, X0, X1) ≈ kernelmatrix(prod_k, X0, X1)
        end
    end
end
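
The LinearLayer equivalence tested above reduces to a one-line identity: LinearLayer applies softplus to its weights before the matrix product, so a 1×2 weight matrix turns the NKN into softplus(w₁)·k₁ + softplus(w₂)·k₂. A standalone hedged check of that arithmetic (illustrative values, not part of the test suite):

```julia
using LogExpFunctions: softplus

# Stand-in kernel values k1(x, y) = 2.0 and k2(x, y) = 3.0 at one input pair.
W = [0.3 1.2]
lhs = softplus.(W) * [2.0, 3.0]                  # what LinearLayer computes
rhs = softplus(0.3) * 2.0 + softplus(1.2) * 3.0  # the weighted-sum kernel
@assert only(lhs) ≈ rhs
```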

test/runtests.jl

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@ using Distances
 using Documenter
 using Kronecker: Kronecker
 using LinearAlgebra
+using LogExpFunctions
 using PDMats
 using Random
 using SpecialFunctions

@@ -123,6 +124,7 @@ include("test_utils.jl")
     include(joinpath("kernels", "scaledkernel.jl"))
     include(joinpath("kernels", "transformedkernel.jl"))
     include(joinpath("kernels", "normalizedkernel.jl"))
+    include(joinpath("kernels", "neuralkernelnetwork.jl"))
 end
 @info "Ran tests on Kernel"
