Skip to content

Commit 716285c

Browse files
authored
Merge pull request #42 from theogf/wct/fix-typo
Fix typo
2 parents 2d846fe + 06c3519 commit 716285c

File tree

3 files changed

+29
-29
lines changed

3 files changed

+29
-29
lines changed

src/kernels/kernelsum.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,28 @@ function KernelSum(
2222
)
2323
@assert length(kernels) == length(weights) "Weights and kernel vector should be of the same length"
2424
@assert all(weights .>= 0) "All weights should be positive"
25-
KernelSum(kernels, weights)
25+
return KernelSum(kernels, weights)
2626
end
2727

2828
params(k::KernelSum) = (k.weights, params.(k.kernels))
2929
opt_params(k::KernelSum) = (k.weights, opt_params.(k.kernels))
3030

3131
Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0])
32-
Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ), first(k2.σ)])
32+
Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ²), first(k2.σ²)])
3333
Base.:+(k1::KernelSum, k2::KernelSum) =
3434
KernelSum(vcat(k1.kernels, k2.kernels), weights = vcat(k1.weights, k2.weights))
3535
Base.:+(k::Kernel, ks::KernelSum) =
3636
KernelSum(vcat(k, ks.kernels), weights = vcat(1.0, ks.weights))
3737
Base.:+(k::ScaledKernel, ks::KernelSum) =
38-
KernelSum(vcat(kernel(k), ks.kernels), weights = vcat(first(k.σ), ks.weights))
38+
KernelSum(vcat(kernel(k), ks.kernels), weights = vcat(first(k.σ²), ks.weights))
3939
Base.:+(k::ScaledKernel, ks::Kernel) =
40-
KernelSum(vcat(kernel(k), ks), weights = vcat(first(k.σ), 1.0))
40+
KernelSum(vcat(kernel(k), ks), weights = vcat(first(k.σ²), 1.0))
4141
Base.:+(ks::KernelSum, k::Kernel) =
4242
KernelSum(vcat(ks.kernels, k), weights = vcat(ks.weights, 1.0))
4343
Base.:+(ks::KernelSum, k::ScaledKernel) =
44-
KernelSum(vcat(ks.kernels, kernel(k)), weights = vcat(ks.weights, first(k.σ)))
44+
KernelSum(vcat(ks.kernels, kernel(k)), weights = vcat(ks.weights, first(k.σ²)))
4545
Base.:+(ks::Kernel, k::ScaledKernel) =
46-
KernelSum(vcat(ks, kernel(k)), weights = vcat(1.0, first(k.σ)))
46+
KernelSum(vcat(ks, kernel(k)), weights = vcat(1.0, first(k.σ²)))
4747
Base.:*(w::Real, k::KernelSum) = KernelSum(k.kernels, weights = w * k.weights) #TODO add tests
4848

4949
Base.length(k::KernelSum) = length(k.kernels)

src/kernels/scaledkernel.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1-
struct ScaledKernel{Tk<:Kernel,<:Real} <: Kernel
1+
struct ScaledKernel{Tk<:Kernel, Tσ²<:Real} <: Kernel
22
kernel::Tk
3-
σ::Vector{Tσ}
3+
σ²::Vector{Tσ²}
44
end
55

6-
function ScaledKernel(kernel::Tk,σ::Tσ=1.0) where {Tk<:Kernel,Tσ<:Real}
7-
@check_args(ScaledKernel, σ, σ > zero(Tσ), "σ > 0")
8-
ScaledKernel{Tk,}(kernel,])
6+
function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real}
7+
@check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0")
8+
return ScaledKernel{Tk, Tσ²}(kernel, [σ²])
99
end
1010

11-
kappa(k::ScaledKernel, x) = first(k.σ)*kappa(k.kernel, x)
11+
kappa(k::ScaledKernel, x) = first(k.σ²) * kappa(k.kernel, x)
1212

1313
metric(k::ScaledKernel) = metric(k.kernel)
1414

15-
params(k::ScaledKernel) = (k.σ,params(k.kernel))
16-
opt_params(k::ScaledKernel) = (k.σ,opt_params(k.kernel))
15+
params(k::ScaledKernel) = (k.σ², params(k.kernel))
16+
opt_params(k::ScaledKernel) = (k.σ², opt_params(k.kernel))
1717

18-
Base.:*(w::Real,k::Kernel) = ScaledKernel(k,w)
18+
Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)
1919

20-
Base.show(io::IO::ScaledKernel) = printshifted(io,κ,0)
20+
Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)
2121

22-
function printshifted(io::IO::ScaledKernel,shift::Int)
23-
printshifted(io,κ.kernel,shift)
24-
print(io,"\n"*("\t"^(shift+1))*"- σ = $(first.σ))")
22+
function printshifted(io::IO, κ::ScaledKernel, shift::Int)
23+
printshifted(io, κ.kernel, shift)
24+
print(io,"\n" * ("\t"^(shift+1)) * "- σ² = $(first.σ²))")
2525
end

test/runtests.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ using Distances
44
using Random
55

66
@testset "KernelFunctions" begin
7-
include("test_kernelmatrix.jl")
8-
include("test_approximations.jl")
9-
include("test_constructors.jl")
10-
# include("test_AD.jl")
11-
include("test_transform.jl")
12-
include("test_distances.jl")
13-
include("test_kernels.jl")
14-
include("test_generic.jl")
15-
include("test_adjoints.jl")
16-
include("test_custom.jl")
7+
include("test_kernelmatrix.jl")
8+
include("test_approximations.jl")
9+
include("test_constructors.jl")
10+
# include("test_AD.jl")
11+
include("test_transform.jl")
12+
include("test_distances.jl")
13+
include("test_kernels.jl")
14+
include("test_generic.jl")
15+
include("test_adjoints.jl")
16+
include("test_custom.jl")
1717
#include("types.jl")
1818
end

0 commit comments

Comments
 (0)