Skip to content

Commit df8a06c

Browse files
authored
Merge branch 'master' into dependabot/github_actions/actions/upload-artifact-4
2 parents e7b0375 + 9b29249 commit df8a06c

File tree

15 files changed

+174
-25
lines changed

15 files changed

+174
-25
lines changed

.github/workflows/benchmark-comment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
echo ::set-output name=body::$(cat ./pull-request-number.artifact)
4242
# check if the previous comment exists
4343
- name: find comment
44-
uses: peter-evans/find-comment@v1
44+
uses: peter-evans/find-comment@v3
4545
id: fc
4646
with:
4747
issue-number: ${{ steps.output-pull-request-number.outputs.body }}

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.60"
3+
version = "0.10.63"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# KernelFunctions.jl
22

3-
![CI](https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/workflows/CI/badge.svg?branch=master)
3+
[![CI](https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/actions/workflows/ci.yml/badge.svg?branch=master)](https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/actions/workflows/ci.yml?query=branch%3Amaster)
44
[![codecov](https://codecov.io/gh/JuliaGaussianProcesses/KernelFunctions.jl/branch/master/graph/badge.svg?token=rmDh3gb7hN)](https://codecov.io/gh/JuliaGaussianProcesses/KernelFunctions.jl)
55
[![Documentation (stable)](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable)
66
[![Documentation (latest)](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliagaussianprocesses.github.io/KernelFunctions.jl/dev)

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ export tensor, ⊗, compose
4848

4949
using Compat
5050
using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent
51-
using ChainRulesCore: @thunk, InplaceableThunk
51+
using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk
5252
using CompositionsBase
5353
using Distances
5454
using FillArrays

src/chainrules.jl

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,113 @@ end
111111

112112
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
113113
d = x - y
114-
sind = sinpi.(d)
115-
abs2_sind_r = abs2.(sind) ./ s.r
114+
abs2_sind_r = (sinpi.(d) ./ s.r) .^ 2
116115
val = sum(abs2_sind_r)
117-
gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2)
116+
gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2
118117
function evaluate_pullback::Any)
119-
return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx
118+
= -2Δ .* abs2_sind_r ./ s.r
119+
= ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
120+
return s̄, Δ * gradx, -Δ * gradx
120121
end
121122
return val, evaluate_pullback
122123
end
123124

125+
function ChainRulesCore.rrule(
126+
::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix; dims=2
127+
)
128+
project_x = ProjectTo(x)
129+
function pairwise_pullback(z̄)
130+
Δ = unthunk(z̄)
131+
n = size(x, dims)
132+
= collect(zero(x))
133+
= zero(d.r)
134+
if dims == 1
135+
for j in 1:n, i in 1:n
136+
xi = view(x, i, :)
137+
xj = view(x, j, :)
138+
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- xj)) ./ d.r .^ 2
139+
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
140+
x̄[i, :] += ds
141+
x̄[j, :] -= ds
142+
end
143+
elseif dims == 2
144+
for j in 1:n, i in 1:n
145+
xi = view(x, :, i)
146+
xj = view(x, :, j)
147+
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
148+
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
149+
x̄[:, i] .+= ds
150+
x̄[:, j] .-= ds
151+
end
152+
end
153+
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
154+
return NoTangent(), d̄, @thunk(project_x(x̄))
155+
end
156+
return Distances.pairwise(d, x; dims), pairwise_pullback
157+
end
158+
159+
function ChainRulesCore.rrule(
160+
::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix; dims=2
161+
)
162+
project_x = ProjectTo(x)
163+
project_y = ProjectTo(y)
164+
function pairwise_pullback(z̄)
165+
Δ = unthunk(z̄)
166+
n = size(x, dims)
167+
m = size(y, dims)
168+
= collect(zero(x))
169+
= collect(zero(y))
170+
= zero(d.r)
171+
if dims == 1
172+
for j in 1:m, i in 1:n
173+
xi = view(x, i, :)
174+
yj = view(y, j, :)
175+
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
176+
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
177+
x̄[i, :] .+= ds
178+
ȳ[j, :] .-= ds
179+
end
180+
elseif dims == 2
181+
for j in 1:m, i in 1:n
182+
xi = view(x, :, i)
183+
yj = view(y, :, j)
184+
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
185+
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
186+
x̄[:, i] .+= ds
187+
ȳ[:, j] .-= ds
188+
end
189+
end
190+
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
191+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
192+
end
193+
return Distances.pairwise(d, x, y; dims), pairwise_pullback
194+
end
195+
196+
function ChainRulesCore.rrule(
197+
::typeof(Distances.colwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix
198+
)
199+
project_x = ProjectTo(x)
200+
project_y = ProjectTo(y)
201+
function colwise_pullback(z̄)
202+
Δ = unthunk(z̄)
203+
n = size(x, 2)
204+
= collect(zero(x))
205+
= collect(zero(y))
206+
= zero(d.r)
207+
for i in 1:n
208+
xi = view(x, :, i)
209+
yi = view(y, :, i)
210+
ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2
211+
.-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
212+
x̄[:, i] .+= ds
213+
ȳ[:, i] .-= ds
214+
end
215+
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
216+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
217+
end
218+
return Distances.colwise(d, x, y), colwise_pullback
219+
end
220+
124221
## Reverse Rules for matrix wrappers
125222

126223
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)

src/kernels/kernelsum.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Base.length(k::KernelSum) = length(k.kernels)
4545

4646
_sum(f, ks::Tuple, args...) = f(first(ks), args...) + _sum(f, Base.tail(ks), args...)
4747
_sum(f, ks::Tuple{Tx}, args...) where {Tx} = f(only(ks), args...)
48+
_sum(f, ks::AbstractVector, args...) = sum(k -> f(k, args...), ks)
4849

4950
::KernelSum)(x, y) = _sum((k, x, y) -> k(x, y), κ.kernels, x, y)
5051

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
3+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
35
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
46
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
57
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

test/basekernels/periodic.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
TestUtils.test_interface(PeriodicKernel(; r=[0.9, 0.9]), ColVecs{Float64})
1616
TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64})
1717

18-
# test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff])
19-
@test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff"
18+
test_ADs(r -> PeriodicKernel(; r=exp.(r)), log.(r))
2019
test_params(k, (r,))
2120
end

test/basekernels/sm.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,6 @@
5757
end
5858

5959
# test_ADs(x->spectral_mixture_kernel(exp.(x[1:3]), reshape(x[4:18], 5, 3), reshape(x[19:end], 5, 3)), vcat(log.(αs₁), γs[:], ωs[:]), dims = [5,5])
60-
@test_broken "No tests passing (BaseKernel)"
60+
# No tests passing (BaseKernel)
61+
@test_broken false
6162
end

test/basekernels/wiener.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@
4444
TestUtils.test_interface(k2, x0, x1, x2)
4545
TestUtils.test_interface(k3, x0, x1, x2)
4646
# test_ADs(()->WienerKernel(i=1))
47-
@test_broken "No tests passing"
47+
# No tests passing
48+
@test_broken false
4849
end

test/chainrules.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,23 @@
1919
compare_gradient(:Zygote, [x, y]) do xy
2020
KernelFunctions.Sinus(r)(xy[1], xy[2])
2121
end
22+
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
23+
dist = KernelFunctions.Sinus(r)
24+
@testset "$type" for type in (Vector, SVector{3})
25+
test_rrule(dist, type(rand(3)), type(rand(3)))
26+
end
27+
@testset "$type1, $type2" for type1 in (Matrix, SMatrix{3,2}),
28+
type2 in (Matrix, SMatrix{3,4})
29+
30+
test_rrule(Distances.pairwise, dist, type1(rand(3, 2)); fkwargs=(dims=2,))
31+
test_rrule(
32+
Distances.pairwise,
33+
dist,
34+
type1(rand(3, 2)),
35+
type2(rand(3, 4));
36+
fkwargs=(dims=2,),
37+
)
38+
test_rrule(Distances.colwise, dist, type1(rand(3, 2)), type1(rand(3, 2)))
39+
end
40+
end
2241
end

test/kernels/kernelsum.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,24 @@
22
k1 = LinearKernel()
33
k2 = SqExponentialKernel()
44
k = KernelSum(k1, k2)
5-
@test k == KernelSum([k1, k2]) == KernelSum((k1, k2))
5+
kvec = KernelSum([k1, k2])
6+
@test k == kvec == KernelSum((k1, k2))
67
for (_k1, _k2) in Iterators.product(
78
(k1, KernelSum((k1,)), KernelSum([k1])), (k2, KernelSum((k2,)), KernelSum([k2]))
89
)
910
@test k == _k1 + _k2
11+
@test kvec == _k1 + _k2
1012
end
11-
@test length(k) == 2
12-
@test repr(k) == (
13+
@test length(k) == length(kvec) == 2
14+
@test repr(k) ==
15+
repr(kvec) ==
1316
"Sum of 2 kernels:\n" *
1417
"\tLinear Kernel (c = 0.0)\n" *
1518
"\tSquared Exponential Kernel (metric = Euclidean(0.0))"
16-
)
1719

1820
# Standardised tests.
1921
test_interface(k, Float64)
22+
test_interface(kvec, Float64)
2023
test_interface(ConstantKernel(; c=1.5) + WhiteKernel(), Vector{String})
2124
test_ADs(x -> KernelSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1))
2225
test_interface_ad_perf(2.4, StableRNG(123456)) do c

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using KernelFunctions
22
using AxisArrays
3+
using ChainRulesCore
4+
using ChainRulesTestUtils
35
using Distances
46
using Documenter
57
using Functors: functor

test/test_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ function test_zygote_perf_heuristic(
384384
@test_broken fwd[1] == fwd[2]
385385
end
386386
if passes[3]
387-
@test pb[1] == pb[2]
387+
@test abs(pb[1] - pb[2]) 1
388388
else
389389
@test_broken pb[1] == pb[2]
390390
end

test/transform/selecttransform.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,41 @@
104104
end
105105

106106
@testset "$(AD)" for AD in [:ReverseDiff]
107-
@test_broken ga = gradient(AD, A) do a
108-
testfunction(ta_row, a, 2)
107+
@test_broken let
108+
gx = gradient(AD, X) do x
109+
testfunction(tx_row, x, 2)
110+
end
111+
ga = gradient(AD, A) do a
112+
testfunction(ta_row, a, 2)
113+
end
114+
gx ga
109115
end
110-
@test_broken ga = gradient(AD, A) do a
111-
testfunction(ta_col, a, 1)
116+
@test_broken let
117+
gx = gradient(AD, X) do x
118+
testfunction(tx_col, x, 1)
119+
end
120+
ga = gradient(AD, A) do a
121+
testfunction(ta_col, a, 1)
122+
end
123+
gx ga
112124
end
113-
@test_broken ga = gradient(AD, A) do a
114-
testfunction(ta_row, a, B, 2)
125+
@test_broken let
126+
gx = gradient(AD, X) do x
127+
testfunction(tx_row, x, Y, 2)
128+
end
129+
ga = gradient(AD, A) do a
130+
testfunction(ta_row, a, B, 2)
131+
end
132+
gx ga
115133
end
116-
@test_broken ga = gradient(AD, A) do a
117-
testfunction(ta_col, a, C, 1)
134+
@test_broken let
135+
gx = gradient(AD, X) do x
136+
testfunction(tx_col, x, Z, 1)
137+
end
138+
ga = gradient(AD, A) do a
139+
testfunction(ta_col, a, C, 1)
140+
end
141+
gx ga
118142
end
119143
end
120144

0 commit comments

Comments
 (0)