Commit 5e9d78d

Merge pull request #147 from JuliaGaussianProcesses/relax_input_and_correct_pairwise
Input check relaxation and change of `pairwise`
2 parents 1bfda4d + 372f4cd commit 5e9d78d

File tree: 7 files changed (+80, −26 lines)

Project.toml (1 addition, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 name = "KernelFunctions"
 uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
-version = "0.4.5"
+version = "0.5.0"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
```

docs/src/create_kernel.md (30 additions, 8 deletions)

````diff
@@ -2,19 +2,41 @@
 
 KernelFunctions.jl contains the most popular kernels already, but you might want to make your own!
 
-Here is for example how one can define the Squared Exponential Kernel again:
+Here are a few ways, depending on how complicated your kernel is:
+
+### SimpleKernel for kernel functions depending on a metric
+
+If your kernel function is of the form `k(x, y) = f(d(x, y))` where `d(x, y)` is a `PreMetric`,
+you can construct your custom kernel by defining `kappa` and `metric` for it.
+Here is for example how one can define the `SqExponentialKernel` again:
 
 ```julia
-struct MyKernel <: Kernel end
+struct MyKernel <: KernelFunctions.SimpleKernel end
 
 KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
 KernelFunctions.metric(::MyKernel) = SqEuclidean()
 ```
 
-For a "Base" kernel, where the kernel function is simply a function applied on some metric between two vectors of reals, you only need to:
-- Define your struct inheriting from `Kernel`.
-- Define a `kappa` function.
-- Define the metric used: `SqEuclidean`, `DotProduct`, etc. Note that the term "metric" is overused here.
-- Optional: define any parameter of your kernel as `trainable` by Flux.jl if you want to perform optimization on the parameters. We recommend wrapping all parameters in arrays to allow them to be mutable.
+### Kernel for more complex kernels
+
+If your kernel does not admit such a representation, all you need to do is define `(k::MyKernel)(x, y)` and inherit from `Kernel`.
+For example, we recreate here the `NeuralNetworkKernel`:
+
+```julia
+struct MyKernel <: KernelFunctions.Kernel end
+
+(::MyKernel)(x, y) = asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
+```
+
+Note that kernels defined this way do not use `Distances.jl` and can therefore be a bit slower.
+
+### Additional Options
 
-Once these functions are defined, you can use all the wrapping functions of KernelFunctions.jl.
+Finally, there are additional functions you can define to bring in more features:
+- `KernelFunctions.trainable(k::MyKernel)`: defines the trainable parameters of your kernel and should return a `Tuple` of them.
+  These parameters will be passed to the `Flux.params` function. For some examples, see the `trainable.jl` file in `src/`.
+- `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes across dimensions, you can declare `iskroncompatible(k) = true` to use Kronecker methods.
+- `KernelFunctions.dim(x::MyDataType)`: by default, the dimension of the inputs is only checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
+- `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be overloaded directly if you want to run special checks for your input types.
+- `kernelmatrix(k::MyKernel, ...)`: you can redefine the various `kernelmatrix` functions to optimize the computations.
+- `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel.
````
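As a quick sanity check of the `SimpleKernel` recipe documented in this diff, the snippet below defines the kernel and evaluates a kernel matrix. This is a sketch assuming KernelFunctions.jl (≥ 0.5) and Distances.jl are installed; `MyKernel` and the test inputs are illustrative.

```julia
# Sketch of the SimpleKernel recipe: defining `kappa` and `metric` is
# enough for the generic `kernelmatrix` machinery to work.
# Assumes KernelFunctions.jl >= 0.5 and Distances.jl are available.
using KernelFunctions, Distances

struct MyKernel <: KernelFunctions.SimpleKernel end

KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
KernelFunctions.metric(::MyKernel) = SqEuclidean()

k = MyKernel()
x = [rand(3) for _ in 1:5]
K = kernelmatrix(k, x)  # 5×5 matrix with K[i, j] = exp(-‖x[i] - x[j]‖²)
```

The diagonal entries are `exp(-0) = 1`, and the matrix is symmetric, which is a cheap way to verify the definitions dispatched correctly.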

docs/src/metrics.md (2 additions, 2 deletions)

````diff
@@ -3,9 +3,9 @@
 KernelFunctions.jl relies on [Distances.jl](https://github.com/JuliaStats/Distances.jl) for computing the pairwise matrix.
 To do so a distance measure is needed for each kernel. Two very common ones can already be used: `SqEuclidean` and `Euclidean`.
 However, not all kernels rely on distance metrics satisfying all the usual definitions. That's why additional metrics come with the package, such as `DotProduct` (`<x,y>`) and `Delta` (`δ(x,y)`).
-Note that every `BaseKernel` must have a defined metric defined as :
+Note that every `SimpleKernel` must have a metric defined as:
 ```julia
-metric(::CustomKernel) = SqEuclidean()
+KernelFunctions.metric(::CustomKernel) = SqEuclidean()
 ```
 
 ## Adding a new metric
````

src/distances/pairwise.jl (1 addition, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 # Add our own pairwise function to be able to apply it on vectors
 
-pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, Y')
+pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, permutedims(Y))
 
 pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
```
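The switch from `Y'` to `permutedims(Y)` matters for vectors of vectors: the adjoint `'` acts recursively on the elements (transposing and conjugating each inner vector), while `permutedims` only reshapes the outer container into a row. A self-contained sketch, where the metric `d` is a hypothetical stand-in for a `PreMetric`:

```julia
# `permutedims(Y)` reshapes a length-N vector into a 1×N row matrix
# without touching its elements, which is exactly what the broadcast
# needs; `Y'` would additionally transpose each inner vector.
d(x, y) = sum(abs2, x - y)  # stand-in for SqEuclidean on two vectors

X = [[0.0, 0.0], [1.0, 1.0]]
Y = [[1.0, 2.0], [3.0, 4.0]]

K = broadcast(d, X, permutedims(Y))  # 2×2 with K[i, j] = d(X[i], Y[j])
```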

src/matrix/kernelmatrix.jl (2 additions, 2 deletions)

```diff
@@ -50,7 +50,7 @@ end
 kernelmatrix(κ::Kernel, x::AbstractVector) = kernelmatrix(κ, x, x)
 
 function kernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector)
-    validate_dims(x, y)
+    validate_inputs(x, y)
     return κ.(x, permutedims(y))
 end
 
@@ -89,7 +89,7 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
 end
 
 function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
-    validate_dims(x, y)
+    validate_inputs(x, y)
    return map(d -> kappa(κ, d), pairwise(metric(κ), x, y))
 end
```
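For a `SimpleKernel`, the matrix is computed in two stages: a pairwise matrix via the kernel's metric, then `kappa` applied elementwise. A self-contained sketch mirroring that structure (the `my_pairwise`, `sqeuclidean`, and `kappa` definitions here are simplified stand-ins for the package internals, not the package API):

```julia
# Two-stage SimpleKernel path: pairwise "distances" first, then kappa
# applied elementwise. All definitions are simplified stand-ins.
my_pairwise(d, X, Y) = broadcast(d, X, permutedims(Y))
sqeuclidean(x, y) = sum(abs2, x - y)
kappa(d2) = exp(-d2)  # SqExponentialKernel-style kappa

x = [rand(2) for _ in 1:3]
y = [rand(2) for _ in 1:4]

K = map(kappa, my_pairwise(sqeuclidean, x, y))  # 3×4 cross-kernel matrix
```

Splitting the computation this way is what lets `SimpleKernel`s reuse the optimized `pairwise` routines instead of evaluating the kernel pair by pair.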

src/utils.jl (16 additions, 12 deletions)

```diff
@@ -20,9 +20,6 @@ function vec_of_vecs(X::AbstractMatrix; obsdim::Int = 2)
     end
 end
 
-dim(x::AbstractVector{<:Real}) = 1
-dim(x::AbstractVector{Tuple{Any,Int}}) = 1
-
 """
     ColVecs(X::AbstractMatrix)
 
@@ -94,9 +91,24 @@ For a transform return its parameters, for a `ChainTransform` return a vector of
 """
 #params
 
+dim(x) = 0 # Pass-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype.
+dim(x::AbstractVector) = dim(first(x))
+dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
+dim(x::AbstractVector{<:Real}) = 1
+
+
+function validate_inputs(x, y)
+    if dim(x) != dim(y) # Passes by default if `dim` is not defined
+        throw(DimensionMismatch(
+            "Dimensionality of x ($(dim(x))) is not equal to that of y ($(dim(y)))",
+        ))
+    end
+    return nothing
+end
+
 
 function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::AbstractVector)
-    validate_dims(x, y)
+    validate_inputs(x, y)
     if size(K) != (length(x), length(y))
         throw(DimensionMismatch(
             "Size of the target matrix K ($(size(K))) not consistent with lengths of " *
@@ -117,11 +129,3 @@ function validate_inplace_dims(K::AbstractVector, x::AbstractVector)
         ))
     end
 end
-
-function validate_dims(x::AbstractVector, y::AbstractVector)
-    if dim(x) != dim(y)
-        throw(DimensionMismatch(
-            "Dimensionality of x ($(dim(x))) not equality to that of y ($(dim(y)))",
-        ))
-    end
-end
```
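The new `dim`/`validate_inputs` pair can be exercised on its own. The following reimplements the dispatch chain from the diff above outside the package, so its "pass by default" behaviour for unknown input types is easy to see:

```julia
# Standalone copy of the dispatch chain: unknown input types report
# dimension 0 and therefore pass the check by default.
dim(x) = 0
dim(x::AbstractVector) = dim(first(x))
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
dim(x::AbstractVector{<:Real}) = 1

function validate_inputs(x, y)
    if dim(x) != dim(y)
        throw(DimensionMismatch("dim(x) = $(dim(x)) != dim(y) = $(dim(y))"))
    end
    return nothing
end

validate_inputs([rand(3) for _ in 1:2], [rand(3) for _ in 1:4])  # same dim: ok
dim(["a", "b"])  # 0: strings have no registered dimension, so checks pass
```

Note that `validate_inputs`, unlike the removed `validate_dims`, accepts arbitrary input types rather than only `AbstractVector`s, which is the "input check relaxation" in the PR title.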

test/utils.jl (28 additions, 0 deletions)

```diff
@@ -75,4 +75,32 @@
         @test back(ones(size(X)))[1].X == ones(size(X))
     end
 end
+@testset "input checks" begin
+    D = 3; D⁻ = 2
+    N1 = 2; N2 = 3
+    x = [rand(rng, D) for _ in 1:N1]
+    x⁻ = [rand(rng, D⁻) for _ in 1:N1]
+    y = [rand(rng, D) for _ in 1:N2]
+    xx = [rand(rng, D, D) for _ in 1:N1]
+    xx⁻ = [rand(rng, D, D⁻) for _ in 1:N1]
+    yy = [rand(rng, D, D) for _ in 1:N2]
+
+    @test KernelFunctions.dim("string") == 0
+    @test KernelFunctions.dim(["string", "string2"]) == 0
+    @test KernelFunctions.dim(rand(rng, 4)) == 1
+    @test KernelFunctions.dim(x) == D
+
+    @test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1, N2), x, y)
+    @test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N1), x, y)
+    @test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N2), x⁻, y)
+    @test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1, N1), x)
+    @test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1), x)
+    @test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N2), x)
+
+    @test_nowarn KernelFunctions.validate_inputs(x, y)
+    @test_throws DimensionMismatch KernelFunctions.validate_inputs(x⁻, y)
+
+    @test_nowarn KernelFunctions.validate_inputs(xx, yy)
+    @test_nowarn KernelFunctions.validate_inputs(xx⁻, yy)
+end
 end
```
