Skip to content

Commit b8ae35c

Browse files
github-actions[bot]st--willtebbutt
authored
CompatHelper: bump compat for "ChainRulesCore" to "1" (#344)
* CompatHelper: bump compat for "ChainRulesCore" to "1" * bump patch version * Update InplaceableThunks * Restrict to version 1 * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Remove Flux dependency * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Remove another Flux reference * FD + Zygote issue * Tweak fbm tests * Note Flux code Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: ST John <[email protected]> Co-authored-by: WT <[email protected]> Co-authored-by: willtebbutt <[email protected]>
1 parent 2203f49 commit b8ae35c

File tree

8 files changed

+39
-26
lines changed

8 files changed

+39
-26
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.13"
3+
version = "0.10.14"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -21,7 +21,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2121
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2222

2323
[compat]
24-
ChainRulesCore = "0.9.44, 0.10"
24+
ChainRulesCore = "1"
2525
Compat = "3.7"
2626
CompositionsBase = "0.1"
2727
Distances = "0.10"

src/chainrules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ function ChainRulesCore.rrule(
127127
function SqMahalanobis_pullback::Real)
128128
a_b = a - b
129129
∂qmat = InplaceableThunk(
130-
@thunk((a_b * a_b') * Δ), -> mul!(X̄, a_b, a_b', true, Δ)
130+
-> mul!(X̄, a_b, a_b', true, Δ), @thunk((a_b * a_b') * Δ)
131131
)
132132
∂a = InplaceableThunk(
133-
@thunk((2 * Δ) * dist.qmat * a_b), -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ)
133+
-> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), @thunk((2 * Δ) * dist.qmat * a_b)
134134
)
135135
∂b = InplaceableThunk(
136-
@thunk((-2 * Δ) * dist.qmat * a_b), -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ)
136+
-> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), @thunk((-2 * Δ) * dist.qmat * a_b)
137137
)
138138
return Tangent{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b
139139
end

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
33
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
44
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
55
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
6-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
76
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
88
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
@@ -20,8 +20,8 @@ AxisArrays = "0.4.3"
2020
Distances = "0.9, 0.10"
2121
Documenter = "0.25, 0.26, 0.27"
2222
FiniteDifferences = "0.10.8, 0.11, 0.12"
23-
Flux = "0.10, 0.11, 0.12"
2423
ForwardDiff = "0.10"
24+
Functors = "0.2"
2525
Kronecker = "0.4"
2626
LogExpFunctions = "0.2, 0.3"
2727
PDMats = "0.9, 0.10, 0.11"

test/basekernels/fbm.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,12 @@
1515
@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
1616
test_ADs(FBMKernel; ADs=[:ReverseDiff])
1717

18-
# Tests failing for ForwardDiff and [email protected] (obtained with Julia > 1.3).
18+
# Tests failing for ForwardDiff and [email protected].
1919
# Related to: https://github.com/FluxML/Zygote.jl/issues/1036
20-
@test_broken !isinf(ForwardDiff.gradient(x -> x[1]^x[2], [0.0, 0.9])[1])
21-
if VERSION >= v"1.4.0"
22-
f(x, y) = x^y
23-
@test_broken !isinf(
24-
Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1]
25-
)
26-
else
27-
test_ADs(FBMKernel; ADs=[:Zygote])
28-
end
20+
f(x, y) = x^y
21+
@test_broken !isinf(
22+
Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1]
23+
)
2924

3025
test_params(k, ([h],))
3126
end

test/kernels/neuralkernelnetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using KernelFunctions: NeuralKernelNetwork, LinearLayer, product, Primitive
1919
primitives = Primitive(k1, k2)
2020

2121
# Build NKN Kernel.
22-
nkn = NeuralKernelNetwork(primitives, Chain(LinearLayer(2, 2), product))
22+
nkn = NeuralKernelNetwork(primitives, product)
2323

2424
# Apply standard test suite.
2525
TestUtils.test_interface(nkn, Float64)

test/kernels/transformedkernel.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,21 @@
3030
# Test implicit gradients
3131
@testset "Implicit gradients" begin
3232
k = SqExponentialKernel() ScaleTransform(2.0)
33-
ps = Flux.params(k)
33+
ps = params(k)
3434
X = rand(10, 1)
3535
x = vec(X)
3636
A = rand(10, 10)
3737
# Implicit
38-
g1 = Flux.gradient(ps) do
38+
g1 = Zygote.gradient(ps) do
3939
tr(kernelmatrix(k, X; obsdim=1) * A)
4040
end
4141
# Explicit
42-
g2 = Flux.gradient(k) do k
42+
g2 = Zygote.gradient(k) do k
4343
tr(kernelmatrix(k, X; obsdim=1) * A)
4444
end
4545

4646
# Implicit for a vector
47-
g3 = Flux.gradient(ps) do
47+
g3 = Zygote.gradient(ps) do
4848
tr(kernelmatrix(k, x) * A)
4949
end
5050
@test g1[first(ps)] first(g2).transform.s
@@ -53,12 +53,12 @@
5353

5454
@testset "Parameters" begin
5555
k = ConstantKernel(; c=rand(rng))
56-
c = Chain(Dense(3, 2))
56+
# c = Chain(Dense(3, 2))
5757

5858
test_params(k ScaleTransform(s), (k, [s]))
5959
test_params(k ARDTransform(v), (k, v))
6060
test_params(k LinearTransform(P), (k, P))
6161
test_params(k LinearTransform(P) ScaleTransform(s), (k, [s], P))
62-
test_params(k FunctionTransform(c), (k, c))
62+
# test_params(k ∘ FunctionTransform(c), (k, c))
6363
end
6464
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ using KernelFunctions
22
using AxisArrays
33
using Distances
44
using Documenter
5+
using Functors: functor
56
using Kronecker: Kronecker
67
using LinearAlgebra
78
using LogExpFunctions
89
using PDMats
910
using Random
1011
using SpecialFunctions
1112
using Test
12-
using Flux
1313
using Zygote: Zygote
1414
using ForwardDiff: ForwardDiff
1515
using ReverseDiff: ReverseDiff

test/test_utils.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
11
# More test utilities. Can't be included in KernelFunctions because they introduce a number
22
# of additional deps that we don't want to have in the main package.
33

4-
# Check parameters of kernels
4+
# Check parameters of kernels. `trainable`, `params!`, and `params` are taken directly from
5+
# Flux.jl so as to avoid having to depend on Flux at test-time.
6+
trainable(m) = functor(m)[1]
7+
8+
params!(p::Zygote.Params, x::AbstractArray{<:Number}, seen=Zygote.IdSet()) = push!(p, x)
9+
10+
function params!(p::Zygote.Params, x, seen=Zygote.IdSet())
11+
x in seen && return nothing
12+
push!(seen, x)
13+
for child in trainable(x)
14+
params!(p, child, seen)
15+
end
16+
end
17+
18+
function params(m...)
19+
ps = Zygote.Params()
20+
params!(ps, m)
21+
return ps
22+
end
523

624
function test_params(kernel, reference)
725
params_kernel = params(kernel)

0 commit comments

Comments
 (0)