Skip to content

Commit c006d47

Browse files
authored
Add NeuralNetOneKernel (JuliaGaussianProcesses#70)
* Add nnone kernel * Add NerualNetOneKernel * Update docstring * Rename and redefine kernel * Remove P * Add tests * Include tests in runtests.jl * Fix docstring * Add more tests * Fix bug in _kernel * Remove kappa and _kernel and custom kernelmatrix functions * Update docstring * Update docs * Update docs * Improve docstring style and remove wiener kernel's tests
1 parent b10684d commit c006d47

File tree

4 files changed

+75
-1
lines changed

4 files changed

+75
-1
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export MaternKernel, Matern32Kernel, Matern52Kernel
1818
export LinearKernel, PolynomialKernel
1919
export RationalQuadraticKernel, GammaRationalQuadraticKernel
2020
export MahalanobisKernel, GaborKernel, PiecewisePolynomialKernel
21-
export PeriodicKernel
21+
export PeriodicKernel, NeuralNetworkKernel
2222
export KernelSum, KernelProduct
2323
export TransformedKernel, ScaledKernel
2424
export TensorProduct

src/basekernels/nn.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
NeuralNetworkKernel()
3+
4+
Neural network kernel function.
5+
6+
```math
7+
κ(x, y) = asin(x' * y / sqrt[(1 + x' * x) * (1 + y' * y)])
8+
```
9+
# Significance
10+
Neal (1996) pursued the limits of large models, and showed that a Bayesian neural network
11+
becomes a Gaussian process with a **neural network kernel** as the number of units
12+
approaches infinity. Here, we give the neural network kernel for a single-hidden-layer
13+
Bayesian neural network with erf (error function) as the activation function.
14+
15+
# References:
16+
- [GPML Pg 105](http://www.gaussianprocess.org/gpml/chapters/RW4.pdf)
17+
- [Neal(1996)](https://www.cs.toronto.edu/~radford/bnn.book.html)
18+
- [Andrew Gordon Wilson's Thesis Pg 45](http://www.cs.cmu.edu/~andrewgw/andrewgwthesis.pdf)
19+
"""
20+
struct NeuralNetworkKernel <: BaseKernel end
21+
22+
function (κ::NeuralNetworkKernel)(x, y)
23+
return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
24+
end
25+
26+
Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")

test/basekernels/nn.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
@testset "nn" begin
2+
using LinearAlgebra
3+
k = NeuralNetworkKernel()
4+
v1 = rand(3); v2 = rand(3)
5+
@test k(v1,v2) ≈ asin(v1' * v2 / sqrt((1 + v1' * v1) * (1 + v2' * v2))) atol=1e-5
6+
7+
# kernelmatrix tests
8+
m1 = rand(3,4)
9+
m2 = rand(3,4)
10+
@test kernelmatrix(k, m1, m1) ≈ kernelmatrix(k, m1) atol=1e-5
11+
@test_broken kernelmatrix(k, m1, m2) ≈ k(m1, m2) atol=1e-5
12+
13+
14+
x1 = rand()
15+
x2 = rand()
16+
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5
17+
18+
@test k(v1, v2) ≈ k(v1, v2) atol=1e-5
19+
@test typeof(k(v1, v2)) <: Real
20+
21+
@test_broken size(k(m1, m2)) == (4, 4)
22+
@test_broken size(k(m1)) == (4, 4)
23+
24+
A1 = ones(4, 4)
25+
kernelmatrix!(A1, k, m1, m2)
26+
@test A1 ≈ kernelmatrix(k, m1, m2) atol=1e-5
27+
28+
A2 = ones(4, 4)
29+
kernelmatrix!(A2, k, m1)
30+
@test A2 ≈ kernelmatrix(k, m1) atol=1e-5
31+
32+
@test size(kerneldiagmatrix(k, m1)) == (4,)
33+
A3 = kernelmatrix(k, m1)
34+
@test kerneldiagmatrix(k, m1) ≈ [A3[i, i] for i in 1:LinearAlgebra.checksquare(A3)] atol=1e-5
35+
36+
A4 = ones(4)
37+
kerneldiagmatrix!(A4, k, m1)
38+
@test kerneldiagmatrix(k, m1) ≈ A4 atol=1e-5
39+
40+
A5 = ones(4,4)
41+
@test_throws AssertionError kernelmatrix!(A5, k, m1, m2, obsdim=3)
42+
@test_throws AssertionError kernelmatrix!(A5, k, m1, obsdim=3)
43+
@test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))
44+
45+
@test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
46+
47+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ using KernelFunctions: metric, kappa
7171
include(joinpath("basekernels", "gabor.jl"))
7272
include(joinpath("basekernels", "maha.jl"))
7373
include(joinpath("basekernels", "matern.jl"))
74+
include(joinpath("basekernels", "nn.jl"))
7475
include(joinpath("basekernels", "periodic.jl"))
7576
include(joinpath("basekernels", "polynomial.jl"))
7677
include(joinpath("basekernels", "piecewisepolynomial.jl"))

0 commit comments

Comments
 (0)