Skip to content

Commit 5f0cdc2

Browse files
authored
Merge pull request #110 from devmotion/zygote_alternative
2 parents a4572ad + 8500413 commit 5f0cdc2

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ Requires = "1.0.1"
2020
SpecialFunctions = "0.8, 0.9, 0.10"
2121
StatsBase = "0.32, 0.33"
2222
StatsFuns = "0.8, 0.9"
23-
Zygote = "= 0.4.16"
2423
ZygoteRules = "0.2"
2524
julia = "1.3"
2625

test/trainable.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,63 @@
11
@testset "trainable" begin
2-
using Flux: params
32
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5; r = rand(3)
43

4+
function test_params(kernel, reference)
5+
params_kernel = params(kernel)
6+
params_reference = params(reference)
7+
8+
@test length(params_kernel) == length(params_reference)
9+
@test all(p == q for (p, q) in zip(params_kernel, params_reference))
10+
end
11+
512
kc = ConstantKernel(c=c)
6-
@test all(params(kc) .== params([c]))
13+
test_params(kc, ([c],))
714

815
kfbm = FBMKernel(h = h)
9-
@test all(params(kfbm) .== params([h]))
16+
test_params(kfbm, ([h],))
1017

1118
kge = GammaExponentialKernel=γ)
12-
@test all(params(kge) .== params([γ]))
19+
test_params(kge, ([γ],))
1320

1421
kgr = GammaRationalQuadraticKernel=γ, α=α)
15-
@test all(params(kgr) .== params([α], [γ]))
22+
test_params(kgr, ([α], [γ]))
1623

1724
kl = LinearKernel(c=c)
18-
@test all(params(kl) .== params([c]))
25+
test_params(kl, ([c],))
1926

2027
km = MaternKernel=ν)
21-
@test all(params(km) .== params([ν]))
28+
test_params(km, ([ν],))
2229

2330
kp = PolynomialKernel(c=c, d=d)
24-
@test all(params(kp) .== params([d], [c]))
31+
test_params(kp, ([d], [c]))
2532

2633
kpe = PeriodicKernel(r = r)
27-
@test all(params(kpe) .== params(r))
34+
test_params(kpe, (r,))
2835

2936
kr = RationalQuadraticKernel=α)
30-
@test all(params(kr) .== params([α]))
37+
test_params(kr, ([α],))
3138

3239
k = km + kc
33-
@test all(params(k) .== params([k.weights], km, kc))
40+
test_params(k, (k.weights, km, kc))
3441

3542
k = km * kc
36-
@test all(params(k) .== params(km, kc))
43+
test_params(k, (km, kc))
3744

3845
s = 2.0
3946
k = transform(km, s)
40-
@test all(params(k) .== params([s], km))
47+
test_params(k, ([s], km))
4148

4249
v = [2.0]
4350
k = transform(kc, v)
44-
@test all(params(k) .== params(v, kc))
51+
test_params(k, (v, kc))
4552

4653
P = rand(3, 2)
4754
k = transform(km, LinearTransform(P))
48-
@test all(params(k) .== params(P, km))
55+
test_params(k, (P, km))
4956

5057
k = transform(km, LinearTransform(P) ScaleTransform(s))
51-
@test all(params(k) .== params([s], P, km))
58+
test_params(k, ([s], P, km))
5259

5360
c = Chain(Dense(3, 2))
5461
k = transform(km, FunctionTransform(c))
55-
@test all(params(k) .== params(c, km))
62+
test_params(k, (c, km))
5663
end

0 commit comments

Comments
 (0)