Skip to content

Fix method ambiguities #483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.48"
version = "0.10.49"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
2 changes: 2 additions & 0 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ using StatsBase
using TensorCore
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

using SparseArrays: SparseArrays

# Hack to work around Zygote type inference problems.
const Distances_pairwise = Distances.pairwise

Expand Down
5 changes: 4 additions & 1 deletion src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

# Note that this is type piracy as the derivative should be NaN for x == y.
function ChainRulesCore.frule(
(_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector
(_, Δx, Δy)::Tuple{<:Any,<:Any,<:Any},
d::Distances.Euclidean,
x::AbstractVector,
y::AbstractVector,
)
Δ = x - y
D = sqrt(sum(abs2, Δ))
Expand Down
8 changes: 8 additions & 0 deletions src/kernels/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,13 @@ for (M, op, T) in (

$M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
$M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))

# Fix method ambiguity issues
function $M.$op(ks1::$T, ks2::$T{<:AbstractVector{<:Kernel}})
return $T(vcat(collect(ks1.kernels), ks2.kernels))
end
function $M.$op(ks1::$T{<:AbstractVector{<:Kernel}}, ks2::$T)
return $T(vcat(ks1.kernels, collect(ks2.kernels)))
end
end
end
7 changes: 5 additions & 2 deletions src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ function ChainTransform(v, θ::AbstractVector)
end

Base.:∘(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁))
Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform(tuple(tc.transforms..., t))
Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(tuple(t, tc.transforms...))
Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform((tc.transforms..., t))
Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform((t, tc.transforms...))
function Base.:∘(tc1::ChainTransform, tc2::ChainTransform)
return ChainTransform((tc2.transforms..., tc1.transforms...))
end

(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

Expand Down
3 changes: 3 additions & 0 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ abstract type Transform end
Base.map(t::Transform, x::AbstractVector) = _map(t, x)
_map(t::Transform, x::AbstractVector) = t.(x)

# Fix method ambiguity issues
Base.map(t::Transform, x::SparseArrays.SparseVector) = _map(t, x)

"""
IdentityTransform()

Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ _to_colvecs(x::AbstractVector{<:Real}) = ColVecs(reshape(x, 1, :))

pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::ColVecs)
return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2)
end
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector{<:AbstractVector{<:Real}})
return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
Expand Down Expand Up @@ -172,10 +172,10 @@ dim(x::RowVecs) = size(x.X, 2)

pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1)
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1)
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::RowVecs)
return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
end
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector{<:AbstractVector{<:Real}})
return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
Expand Down
6 changes: 6 additions & 0 deletions test/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
k2 = SqExponentialKernel()
k = KernelProduct(k1, k2)
@test k == KernelProduct([k1, k2]) == KernelProduct((k1, k2))
for (_k1, _k2) in Iterators.product(
(k1, KernelProduct((k1,)), KernelProduct([k1])),
(k2, KernelProduct((k2,)), KernelProduct([k2])),
)
@test k == _k1 * _k2
end
@test length(k) == 2
@test string(k) == (
"Product of 2 kernels:\n\tLinear Kernel (c = 0.0)\n\tSquared " *
Expand Down
6 changes: 6 additions & 0 deletions test/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
k2 = SqExponentialKernel()
k = KernelSum(k1, k2)
@test k == KernelSum([k1, k2]) == KernelSum((k1, k2))
for (_k1, _k2) in Iterators.product(
(k1, KernelSum((k1,)), KernelSum([k1])),
(k2, KernelSum((k2,)), KernelSum([k2])),
)
@test k == _k1 + _k2
end
@test length(k) == 2
@test string(k) == (
"Sum of 2 kernels:\n" *
Expand Down
6 changes: 6 additions & 0 deletions test/kernels/kerneltensorproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@

@test kernel1 == kernel2
@test kernel1.kernels === (k1, k2) === KernelTensorProduct((k1, k2)).kernels
for (_k1, _k2) in Iterators.product(
(k1, KernelTensorProduct((k1,)), KernelTensorProduct([k1])),
(k2, KernelTensorProduct((k2,)), KernelTensorProduct([k2])),
)
@test kernel1 == _k1 ⊗ _k2
end
@test length(kernel1) == length(kernel2) == 2
@test string(kernel1) == (
"Tensor product of 2 kernels:\n" *
Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ include("test_utils.jl")
if GROUP == "" || GROUP == "Others"
include("utils.jl")

@test isempty(detect_unbound_args(KernelFunctions))
@testset "general" begin
@test isempty(detect_unbound_args(KernelFunctions))
@test isempty(detect_ambiguities(KernelFunctions))
end

@testset "distances" begin
include("distances/pairwise.jl")
Expand Down
1 change: 1 addition & 0 deletions test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Check composition constructors.
@test (tf ∘ ChainTransform([tp])).transforms == (tp, tf)
@test (ChainTransform([tf]) ∘ tp).transforms == (tp, tf)
@test (ChainTransform([tf]) ∘ ChainTransform([tp])).transforms == (tp, tf)

# Verify correctness.
x = ColVecs(randn(rng, 2, 3))
Expand Down