Use tuples in KernelSum and KernelProduct; Make Zygote tests pass for KernelSum and KernelProduct; Improve docstring. #146
Merged
Changes shown from 15 of 19 commits:

- 4dc64b0 Use tuples in KernelSum and KernelProduct (sharanry)
- 4d6fb2b Remove weights in KernelSum (sharanry)
- eb191c4 Add jldocs and remove duplicate functions (sharanry)
- 52d346d Revert userguide and fix formatting (sharanry)
- be77e9c Add more jldoctests (sharanry)
- b56ec5c Fix test/trainable.jl (sharanry)
- 9d6a025 Add detailed docstring (sharanry)
- 991453e Iterate over kernels (sharanry)
- 7bfee65 Sum/product of array of kernels results in array of kernels (sharanry)
- dc13bf9 Modify show (sharanry)
- e4a9bf3 Fix AD test (sharanry)
- 7de1ae4 Add tests for show function (sharanry)
- 6c29199 Zygote tests pass now (sharanry)
- b0ddd11 Modify + and * functions (sharanry)
- 4fe48ef Define '==' in TensorProduct (sharanry)
- e2bfafc Address code review (sharanry)
- 44b7b0d Update kernels.md (sharanry)
- b2f2ac1 Patch bump (sharanry)
- ac2871f Merge branch 'master' into sharan/use-tuples (sharanry)
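The commit history centers on one idea: replace the abstractly typed `Vector{Kernel}` field with a type-parameterized container. The following is a hypothetical minimal sketch of why that matters; `VectorSum`, `TupleSum`, and the stand-in kernel types below are illustrations, not the actual KernelFunctions.jl definitions.

```julia
# Simplified stand-in types (NOT the real KernelFunctions.jl definitions).
abstract type Kernel end
struct SqExponential <: Kernel end
struct Linear <: Kernel end

# Pre-PR style: the field's element type is abstract, so the compiler
# cannot specialize on the components (and Zygote struggled with this).
struct VectorSum <: Kernel
    kernels::Vector{Kernel}
end

# Post-PR style: a Tuple stored behind a type parameter records every
# component's concrete type in the wrapper's own type.
struct TupleSum{Tk} <: Kernel
    kernels::Tk
end

vs = VectorSum([SqExponential(), Linear()])
ts = TupleSum((SqExponential(), Linear()))

eltype(vs.kernels)                  # Kernel (abstract)
typeof(ts.kernels)                  # Tuple{SqExponential, Linear}
isconcretetype(typeof(ts.kernels))  # true
```

Tuples also iterate in a type-stable way, which is consistent with the diffs replacing `κ.kernels[i] for i in 1:length(κ)` loops with plain `for k in κ.kernels` iteration.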
@@ -1,49 +1,103 @@
 """
-    KernelProduct(kernels::Array{Kernel})
+    KernelProduct <: Kernel

-Create a product of kernels.
-One can also use the operator `*` :
+Create a product of kernels. One can also use the overloaded operator `*`.
+
+There are various ways in which you create a `KernelProduct`:
+
+The simplest way to specify a `KernelProduct` would be to use the overloaded `*` operator. This is
+equivalent to creating a `KernelProduct` by specifying the kernels as the arguments to the constructor.
+```jldoctest kernelprod
+julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5);
+
+julia> (k = k1 * k2) == KernelProduct(k1, k2)
+true
+
+julia> kernelmatrix(k1 * k2, X) == kernelmatrix(k1, X) .* kernelmatrix(k2, X)
+true
+
+julia> kernelmatrix(k, X) == kernelmatrix(k1 * k2, X)
+true
+```
-```
-k1 = SqExponentialKernel()
-k2 = LinearKernel()
-k = KernelProduct([k1, k2]) == k1 * k2
-kernelmatrix(k, X) == kernelmatrix(k1, X) .* kernelmatrix(k2, X)
-kernelmatrix(k, X) == kernelmatrix(k1 * k2, X)
-```
+
+You could also specify a `KernelProduct` by providing a `Tuple` or a `Vector` of the
+kernels to be multiplied. We suggest using a `Tuple` when you have fewer components
+and a `Vector` when dealing with a large number of components.
+```jldoctest kernelprod
+julia> KernelProduct((k1, k2)) == k1 * k2
+true
+
+julia> KernelProduct([k1, k2]) == KernelProduct((k1, k2)) == k1 * k2
+true
+```
 """
-struct KernelProduct <: Kernel
-    kernels::Vector{Kernel}
+struct KernelProduct{Tk} <: Kernel
+    kernels::Tk
 end

+function KernelProduct(kernel::Kernel, kernels::Kernel...)
+    return KernelProduct((kernel, kernels...))
+end
+
-Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
-Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
-Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))
-Base.:*(kp::KernelProduct,k::Kernel) = KernelProduct(vcat(kp.kernels,k))
+Base.:*(k1::Kernel, k2::Kernel) = KernelProduct(k1, k2)
+
+function Base.:*(
+    k1::KernelProduct{<:AbstractVector{<:Kernel}},
+    k2::KernelProduct{<:AbstractVector{<:Kernel}},
+)
+    return KernelProduct(vcat(k1.kernels, k2.kernels))
+end
+
+function Base.:*(k1::KernelProduct, k2::KernelProduct)
+    return KernelProduct(k1.kernels..., k2.kernels...) #TODO Add test
+end
+
+function Base.:*(k::Kernel, kp::KernelProduct{<:AbstractVector{<:Kernel}})
+    return KernelProduct(vcat(k, kp.kernels))
+end
+
+Base.:*(k::Kernel, kp::KernelProduct) = KernelProduct(k, kp.kernels...)
+
+function Base.:*(kp::KernelProduct{<:AbstractVector{<:Kernel}}, k::Kernel)
+    return KernelProduct(vcat(kp.kernels, k))
+end
+
+Base.:*(kp::KernelProduct, k::Kernel) = KernelProduct(kp.kernels..., k)

 Base.length(k::KernelProduct) = length(k.kernels)

 (κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)

 function kernelmatrix(κ::KernelProduct, x::AbstractVector)
-    return reduce(hadamard, kernelmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return reduce(hadamard, kernelmatrix(k, x) for k in κ.kernels)
 end

 function kernelmatrix(κ::KernelProduct, x::AbstractVector, y::AbstractVector)
-    return reduce(hadamard, kernelmatrix(κ.kernels[i], x, y) for i in 1:length(κ))
+    return reduce(hadamard, kernelmatrix(k, x, y) for k in κ.kernels)
 end

 function kerneldiagmatrix(κ::KernelProduct, x::AbstractVector)
-    return reduce(hadamard, kerneldiagmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return reduce(hadamard, kerneldiagmatrix(k, x) for k in κ.kernels)
 end

 function Base.show(io::IO, κ::KernelProduct)
     printshifted(io, κ, 0)
 end

+function Base.:(==)(x::KernelProduct, y::KernelProduct)
+    return (
+        length(x.kernels) == length(y.kernels) &&
+        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
+    )
+end
+
 function printshifted(io::IO, κ::KernelProduct, shift::Int)
     print(io, "Product of $(length(κ)) kernels:")
-    for i in 1:length(κ)
-        print(io, "\n" * ("\t" ^ (shift + 1)) * "- ")
-        printshifted(io, κ.kernels[i], shift + 2)
+    for k in κ.kernels
+        print(io, "\n")
+        for _ in 1:(shift + 1)
+            print(io, "\t")
+        end
+        printshifted(io, k, shift + 2)
     end
 end
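To see the tuple-based product machinery end to end, here is a runnable, self-contained sketch. `KProd`, `kmat`, and the stand-in kernel types are hypothetical simplifications of the package's `KernelProduct`, `kernelmatrix`, and kernel types; `hadamard` is written out as elementwise multiplication.

```julia
# Hypothetical simplified types (NOT the real KernelFunctions.jl API).
abstract type Kernel end

struct SqExponential <: Kernel end
(k::SqExponential)(x, y) = exp(-abs2(x - y))

struct Linear <: Kernel end
(k::Linear)(x, y) = x * y

# Tuple-backed product, mirroring KernelProduct{Tk} in the diff above.
struct KProd{Tk} <: Kernel
    kernels::Tk
end
KProd(k::Kernel, ks::Kernel...) = KProd((k, ks...))
Base.:*(k1::Kernel, k2::Kernel) = KProd(k1, k2)
Base.length(k::KProd) = length(k.kernels)

# Pointwise evaluation multiplies the component kernels.
(κ::KProd)(x, y) = prod(k(x, y) for k in κ.kernels)

# Kernel matrices combine via an elementwise (Hadamard) product.
hadamard(A, B) = A .* B
kmat(k::Kernel, x) = [k(xi, xj) for xi in x, xj in x]
kmat(κ::KProd, x) = reduce(hadamard, kmat(k, x) for k in κ.kernels)

x = [0.0, 1.0, 2.0]
k = SqExponential() * Linear()
kmat(k, x) == hadamard(kmat(SqExponential(), x), kmat(Linear(), x))  # true
```

The `reduce(hadamard, ...)` over a generator is the same pattern the diff introduces, just iterating the tuple directly instead of indexing `κ.kernels[i]`.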
@@ -1,73 +1,101 @@
 """
-    KernelSum(kernels::Array{Kernel}; weights::Array{Real}=ones(length(kernels)))
+    KernelSum <: Kernel

-Create a positive weighted sum of kernels. All weights should be positive.
-One can also use the operator `+`
+Create a sum of kernels. One can also use the operator `+`.
+
+There are various ways in which you create a `KernelSum`:
+
+The simplest way to specify a `KernelSum` would be to use the overloaded `+` operator. This is
+equivalent to creating a `KernelSum` by specifying the kernels as the arguments to the constructor.
+```jldoctest kernelsum
+julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5);
+
+julia> (k = k1 + k2) == KernelSum(k1, k2)
+true
+
+julia> kernelmatrix(k1 + k2, X) == kernelmatrix(k1, X) .+ kernelmatrix(k2, X)
+true
+
+julia> kernelmatrix(k, X) == kernelmatrix(k1 + k2, X)
+true
+```
-```
-k1 = SqExponentialKernel()
-k2 = LinearKernel()
-k = KernelSum([k1, k2]) == k1 + k2
-kernelmatrix(k, X) == kernelmatrix(k1, X) .+ kernelmatrix(k2, X)
-kernelmatrix(k, X) == kernelmatrix(k1 + k2, X)
-kweighted = 0.5* k1 + 2.0*k2 == KernelSum([k1, k2], weights = [0.5, 2.0])
-```
+
+You could also specify a `KernelSum` by providing a `Tuple` or a `Vector` of the
+kernels to be summed. We suggest using a `Tuple` when you have fewer components
+and a `Vector` when dealing with a large number of components.
+```jldoctest kernelsum
+julia> KernelSum((k1, k2)) == k1 + k2
+true
+
+julia> KernelSum([k1, k2]) == KernelSum((k1, k2)) == k1 + k2
+true
+```
 """
-struct KernelSum <: Kernel
-    kernels::Vector{Kernel}
-    weights::Vector{Real}
+struct KernelSum{Tk} <: Kernel
+    kernels::Tk
 end

+function KernelSum(kernel::Kernel, kernels::Kernel...)
+    return KernelSum((kernel, kernels...))
+end
+
-function KernelSum(
-    kernels::AbstractVector{<:Kernel};
-    weights::AbstractVector{<:Real} = ones(Float64, length(kernels)),
-)
-    @assert length(kernels) == length(weights) "Weights and kernel vector should be of the same length"
-    @assert all(weights .>= 0) "All weights should be positive"
-    return KernelSum(kernels, weights)
-end
-
-Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0])
-Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ²), first(k2.σ²)])
-Base.:+(k1::KernelSum, k2::KernelSum) =
-    KernelSum(vcat(k1.kernels, k2.kernels), weights = vcat(k1.weights, k2.weights))
-Base.:+(k::Kernel, ks::KernelSum) =
-    KernelSum(vcat(k, ks.kernels), weights = vcat(1.0, ks.weights))
-Base.:+(k::ScaledKernel, ks::KernelSum) =
-    KernelSum(vcat(kernel(k), ks.kernels), weights = vcat(first(k.σ²), ks.weights))
-Base.:+(k::ScaledKernel, ks::Kernel) =
-    KernelSum(vcat(kernel(k), ks), weights = vcat(first(k.σ²), 1.0))
-Base.:+(ks::KernelSum, k::Kernel) =
-    KernelSum(vcat(ks.kernels, k), weights = vcat(ks.weights, 1.0))
-Base.:+(ks::KernelSum, k::ScaledKernel) =
-    KernelSum(vcat(ks.kernels, kernel(k)), weights = vcat(ks.weights, first(k.σ²)))
-Base.:+(ks::Kernel, k::ScaledKernel) =
-    KernelSum(vcat(ks, kernel(k)), weights = vcat(1.0, first(k.σ²)))
-Base.:*(w::Real, k::KernelSum) = KernelSum(k.kernels, weights = w * k.weights) #TODO add tests
+Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)
+
+function Base.:+(
+    k1::KernelSum{<:AbstractVector{<:Kernel}},
+    k2::KernelSum{<:AbstractVector{<:Kernel}},
+)
+    return KernelSum(vcat(k1.kernels, k2.kernels))
+end
+
+Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...)
+
+function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}})
+    return KernelSum(vcat(k, ks.kernels))
+end
+
+Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...)
+
+function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel)
+    return KernelSum(vcat(ks.kernels, k))
+end
+
+Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k)

 Base.length(k::KernelSum) = length(k.kernels)

-(κ::KernelSum)(x, y) = sum(κ.weights[i] * κ.kernels[i](x, y) for i in 1:length(κ))
+(κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)

 function kernelmatrix(κ::KernelSum, x::AbstractVector)
-    return sum(κ.weights[i] * kernelmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return sum(kernelmatrix(k, x) for k in κ.kernels)
 end

 function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
-    return sum(κ.weights[i] * kernelmatrix(κ.kernels[i], x, y) for i in 1:length(κ))
+    return sum(kernelmatrix(k, x, y) for k in κ.kernels)
 end

 function kerneldiagmatrix(κ::KernelSum, x::AbstractVector)
-    return sum(κ.weights[i] * kerneldiagmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return sum(kerneldiagmatrix(k, x) for k in κ.kernels)
 end

 function Base.show(io::IO, κ::KernelSum)
     printshifted(io, κ, 0)
 end

+function Base.:(==)(x::KernelSum, y::KernelSum)
+    return (
+        length(x.kernels) == length(y.kernels) &&
+        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
+    )
+end
+
 function printshifted(io::IO, κ::KernelSum, shift::Int)
     print(io, "Sum of $(length(κ)) kernels:")
-    for i in 1:length(κ)
-        print(io, "\n" * ("\t" ^ (shift + 1)) * "- (w = $(κ.weights[i])) ")
-        printshifted(io, κ.kernels[i], shift + 2)
+    for k in κ.kernels
+        print(io, "\n")
+        for _ in 1:(shift + 1)
+            print(io, "\t")
+        end
+        printshifted(io, k, shift + 2)
    end
 end
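The notable behavioral change in this file is that the explicit `weights` field is gone: a weighted sum is now expressed by scaling the component kernels before adding them. A runnable, self-contained sketch of that design, using hypothetical minimal types (`ScaledK` stands in for a `ScaledKernel`-style wrapper; none of these are the real KernelFunctions.jl definitions):

```julia
# Hypothetical simplified types (NOT the real KernelFunctions.jl API).
abstract type Kernel end

struct Linear <: Kernel end
(k::Linear)(x, y) = x * y

# Weight lives inside the component, not inside the sum.
struct ScaledK{K<:Kernel} <: Kernel
    w::Float64
    k::K
end
(s::ScaledK)(x, y) = s.w * s.k(x, y)
Base.:*(w::Real, k::Kernel) = ScaledK(float(w), k)

# Tuple-backed sum, mirroring KernelSum{Tk} in the diff above.
struct KSum{Tk} <: Kernel
    kernels::Tk
end
KSum(k::Kernel, ks::Kernel...) = KSum((k, ks...))
Base.:+(k1::Kernel, k2::Kernel) = KSum(k1, k2)
Base.:+(ks::KSum, k::Kernel) = KSum(ks.kernels..., k)

# The sum just adds component evaluations; no weights field needed.
(κ::KSum)(x, y) = sum(k(x, y) for k in κ.kernels)

k = 0.5 * Linear() + 2.0 * Linear() + Linear()
k(2.0, 3.0)  # (0.5 + 2.0 + 1.0) * 6.0 = 21.0
```

Pushing the weights into wrapper kernels keeps `KernelSum` a plain container, which simplifies the `+` overloads and the iteration in `kernelmatrix` and `printshifted`.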