
Input check relaxation and change of `pairwise` #147


Merged
merged 12 commits into from Aug 4, 2020
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.4.5"
version = "0.5.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
38 changes: 30 additions & 8 deletions docs/src/create_kernel.md
@@ -2,19 +2,41 @@

KernelFunctions.jl contains the most popular kernels already but you might want to make your own!

Here is for example how one can define the Squared Exponential Kernel again :
Here are a few ways to do so, depending on how complicated your kernel is:

### SimpleKernel for kernel functions depending on a metric

If your kernel function is of the form `k(x, y) = f(d(x, y))`, where `d(x, y)` is a `PreMetric`,
you can construct your custom kernel by defining `kappa` and `metric` for your kernel.
Here is, for example, how one can define the `SqExponentialKernel` again:

```julia
struct MyKernel <: Kernel end
struct MyKernel <: KernelFunctions.SimpleKernel end

KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
KernelFunctions.metric(::MyKernel) = SqEuclidean()
```
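Once `kappa` and `metric` are defined, the package evaluates such a kernel as `kappa` applied to the metric, i.e. `k(x, y) = kappa(k, metric(k)(x, y))`. A minimal hand-rolled sketch of that evaluation, with no packages (`sqeuclidean` here is a stand-in for `Distances.SqEuclidean`, not the package's code path):

```julia
kappa(d2::Real) = exp(-d2)             # as defined for MyKernel above
sqeuclidean(x, y) = sum(abs2, x .- y)  # stand-in for Distances.SqEuclidean

k(x, y) = kappa(sqeuclidean(x, y))

k([0.0, 0.0], [0.0, 0.0])  # == 1.0, since exp(-0) = 1
k([1.0, 0.0], [0.0, 0.0])  # ≈ exp(-1)
```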

For a "Base" kernel, where the kernel function is simply a function applied on some metric between two vectors of real, you only need to:
- Define your struct inheriting from `Kernel`.
- Define a `kappa` function.
- Define the metric used `SqEuclidean`, `DotProduct` etc. Note that the term "metric" is here overabused.
- Optional : Define any parameter of your kernel as `trainable` by Flux.jl if you want to perform optimization on the parameters. We recommend wrapping all parameters in arrays to allow them to be mutable.
### Kernel for more complex kernels

If your kernel does not admit such a representation, all you need to do is define `(k::MyKernel)(x, y)` and inherit from `Kernel`.
For example, we recreate here the `NeuralNetworkKernel`:

```julia
using LinearAlgebra: dot

struct MyKernel <: KernelFunctions.Kernel end

(::MyKernel)(x, y) = asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
```
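The formula itself can be checked stand-alone, without subtyping `Kernel` (`LinearAlgebra` is a stdlib; `nnk` is an illustrative name, not part of the package):

```julia
using LinearAlgebra: dot

# Same formula as the NeuralNetworkKernel body above
nnk(x, y) = asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))

x, y = rand(3), rand(3)
nnk(x, y) ≈ nnk(y, x)  # kernels are symmetric, so this holds
```

By Cauchy–Schwarz the argument of `asin` is always strictly inside `[-1, 1]`, so the evaluation is well defined for any pair of real vectors.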

Note that kernels defined this way do not use `Distances.jl` and can therefore be a bit slower.

### Additional Options

Once these functions are defined, you can use all the wrapping functions of KernelFuntions.jl
Finally, there are additional functions you can define to bring in more features:
- `KernelFunctions.trainable(k::MyKernel)`: defines the trainable parameters of your kernel; it should return a `Tuple` of your parameters.
These parameters will be passed to the `Flux.params` function. For some examples, see the `trainable.jl` file in `src/`.
- `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes across dimensions, you can declare `iskroncompatible(k) = true` to use Kronecker methods.
- `KernelFunctions.dim(x::MyDataType)`: by default, the dimension of the inputs is only checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
- `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be overloaded directly if you want to run special checks for your input types.
- `kernelmatrix(k::MyKernel, ...)`: you can redefine the various `kernelmatrix` functions to optimize the computations.
- `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel.
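The `trainable` convention above can be sketched as follows. This is a hypothetical example — `MyScaledKernel` and the free-standing `trainable` are illustrative stand-ins, not names from the package; in real code the method would be defined as `KernelFunctions.trainable`:

```julia
# A hypothetical kernel with one lengthscale parameter, wrapped in an
# array so that Flux can mutate it in place during optimization.
struct MyScaledKernel
    ℓ::Vector{Float64}
end

# Returns a Tuple of the parameter containers, as the convention requires.
trainable(k::MyScaledKernel) = (k.ℓ,)

k = MyScaledKernel([2.0])
trainable(k)[1] === k.ℓ  # true: Flux sees the actual parameter array
```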
4 changes: 2 additions & 2 deletions docs/src/metrics.md
@@ -3,9 +3,9 @@
KernelFunctions.jl relies on [Distances.jl](https://github.com/JuliaStats/Distances.jl) for computing the pairwise matrix.
To do so, a distance measure is needed for each kernel. Two very common ones can already be used: `SqEuclidean` and `Euclidean`.
However, not all kernels rely on metrics satisfying all the distance axioms. That is why additional metrics come with the package, such as `DotProduct` (`<x,y>`) and `Delta` (`δ(x,y)`).
Note that every `BaseKernel` must have a defined metric defined as :
Note that every `SimpleKernel` must have its metric defined as:
```julia
metric(::CustomKernel) = SqEuclidean()
KernelFunctions.metric(::CustomKernel) = SqEuclidean()
```

## Adding a new metric
2 changes: 1 addition & 1 deletion src/distances/pairwise.jl
@@ -1,6 +1,6 @@
# Add our own pairwise function to be able to apply it on vectors

pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, Y')
pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, permutedims(Y))

pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
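The switch from `Y'` to `permutedims(Y)` matters because adjoint is recursive in Julia: on a vector of vectors, `'` would also adjoint each inner vector, whereas `permutedims` only reshapes the outer container. A self-contained sketch of the broadcast (`pairwise_bc` and `sqeuclidean` are stand-ins, not the package's functions):

```julia
# Broadcast a distance over all pairs; permutedims(Y) is a 1×N matrix
# whose elements are still the original (un-adjointed) inner vectors.
pairwise_bc(d, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, permutedims(Y))
sqeuclidean(x, y) = sum(abs2, x .- y)  # stand-in distance

X = [[1.0, 0.0], [0.0, 1.0]]
Y = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = pairwise_bc(sqeuclidean, X, Y)
size(K)   # (2, 3): one entry per pair
K[1, 1]   # 0.0: distance of [1, 0] to itself
```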

4 changes: 2 additions & 2 deletions src/matrix/kernelmatrix.jl
@@ -50,7 +50,7 @@ end
kernelmatrix(κ::Kernel, x::AbstractVector) = kernelmatrix(κ, x, x)

function kernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector)
validate_dims(x, y)
validate_inputs(x, y)
return κ.(x, permutedims(y))
end

@@ -89,7 +89,7 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
end

function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
validate_dims(x, y)
validate_inputs(x, y)
return map(d -> kappa(κ, d), pairwise(metric(κ), x, y))
end
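This `SimpleKernel` path can be sketched end to end without the package: compute pairwise distances first, then map `kappa` over them elementwise (all names below are illustrative stand-ins for the package's functions):

```julia
# Stand-ins mirroring the SimpleKernel code path above.
kappa_sketch(d2) = exp(-d2)                      # e.g. squared exponential
sqeuclidean(x, y) = sum(abs2, x .- y)
pairwise_sketch(x, y) = broadcast(sqeuclidean, x, permutedims(y))

kernelmatrix_sketch(x, y) = map(kappa_sketch, pairwise_sketch(x, y))

x = [rand(3) for _ in 1:2]
y = [rand(3) for _ in 1:4]
K = kernelmatrix_sketch(x, y)
size(K)  # (2, 4)
```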

28 changes: 16 additions & 12 deletions src/utils.jl
@@ -20,9 +20,6 @@ function vec_of_vecs(X::AbstractMatrix; obsdim::Int = 2)
end
end

dim(x::AbstractVector{<:Real}) = 1
dim(x::AbstractVector{Tuple{Any,Int}}) = 1

"""
ColVecs(X::AbstractMatrix)

@@ -94,9 +91,24 @@ For a transform return its parameters, for a `ChainTransform` return a vector of
"""
#params

dim(x) = 0 # Fallback that makes the check pass by default. For a proper check, implement `KernelFunctions.dim` for your datatype.
dim(x::AbstractVector) = dim(first(x))
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
dim(x::AbstractVector{<:Real}) = 1


function validate_inputs(x, y)
if dim(x) != dim(y) # Passes by default if `dim` is not defined
throw(DimensionMismatch(
"Dimensionality of x ($(dim(x))) not equal to that of y ($(dim(y)))",
))
end
return nothing
end


function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::AbstractVector)
validate_dims(x, y)
validate_inputs(x, y)
if size(K) != (length(x), length(y))
throw(DimensionMismatch(
"Size of the target matrix K ($(size(K))) not consistent with lengths of " *
@@ -117,11 +129,3 @@ function validate_inplace_dims(K::AbstractVector, x::AbstractVector)
))
end
end

function validate_dims(x::AbstractVector, y::AbstractVector)
if dim(x) != dim(y)
throw(DimensionMismatch(
"Dimensionality of x ($(dim(x))) not equality to that of y ($(dim(y)))",
))
end
end
28 changes: 28 additions & 0 deletions test/utils.jl
@@ -75,4 +75,32 @@
@test back(ones(size(X)))[1].X == ones(size(X))
end
end
@testset "input checks" begin
D = 3; D⁻ = 2
N1 = 2; N2 = 3
x = [rand(rng, D) for _ in 1:N1]
x⁻ = [rand(rng, D⁻) for _ in 1:N1]
y = [rand(rng, D) for _ in 1:N2]
xx = [rand(rng, D, D) for _ in 1:N1]
xx⁻ = [rand(rng, D, D⁻) for _ in 1:N1]
yy = [rand(rng, D, D) for _ in 1:N2]

@test KernelFunctions.dim("string") == 0
@test KernelFunctions.dim(["string", "string2"]) == 0
@test KernelFunctions.dim(rand(rng, 4)) == 1
@test KernelFunctions.dim(x) == D

@test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1, N2), x, y)
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N1), x, y)
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N2), x⁻, y)
@test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1, N1), x)
@test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1), x)
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N2), x)

@test_nowarn KernelFunctions.validate_inputs(x, y)
@test_throws DimensionMismatch KernelFunctions.validate_inputs(x⁻, y)

@test_nowarn KernelFunctions.validate_inputs(xx, yy)
@test_nowarn KernelFunctions.validate_inputs(xx⁻, yy)
end
end