Skip to content

Commit e653fe4

Browse files
committed
Remove transform and test deprecations
1 parent 777f7b2 commit e653fe4

15 files changed

+56
-45
lines changed

docs/create_kernel_plots.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ n_grid = 101
1313
fill(x₀, n_grid, 1)
1414
xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1)
1515

16-
k = transform(SqExponentialKernel(), 1.0)
16+
k = SqExponentialKernel() ScaleTransform(1.0)
1717
K1 = kernelmatrix(k, xrange; obsdim=1)
1818
p = heatmap(
1919
K1;
@@ -35,7 +35,7 @@ p = heatmap(
3535
)
3636
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png"))
3737

38-
k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1)))
38+
k = PolynomialKernel(; c=0.0, d=2.0) LinearTransform(randn(3, 1))
3939
K3 = kernelmatrix(k, xrange; obsdim=1)
4040
p = heatmap(
4141
K3;
@@ -47,7 +47,7 @@ p = heatmap(
4747
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png"))
4848

4949
k =
50-
0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) +
50+
0.5 * SqExponentialKernel() * (LinearKernel() ScaleTransform(0.5)) +
5151
0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x)))
5252
K4 = kernelmatrix(k, xrange; obsdim=1)
5353
p = heatmap(

src/basekernels/gabor.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,22 @@ end
2525
::GaborKernel)(x, y) = κ.kernel(x, y)
2626

2727
function _gabor(; ell=nothing, p=nothing)
28-
if ell === nothing
29-
if p === nothing
30-
return SqExponentialKernel() * CosineKernel()
31-
else
32-
return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p)
33-
end
34-
elseif p === nothing
35-
return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel()
28+
ell_transform = if ell === nothing
29+
IdentityTransform()
30+
elseif ell isa Real
31+
ScaleTransform(inv(ell))
3632
else
37-
return transform(SqExponentialKernel(), 1 ./ ell) *
38-
transform(CosineKernel(), 1 ./ p)
33+
ARDTransform(inv.(ell))
3934
end
35+
p_transform = if p === nothing
36+
IdentityTransform()
37+
elseif p isa Real
38+
ScaleTransform(inv(p))
39+
else
40+
ARDTransform(inv.(p))
41+
end
42+
43+
return (SqExponentialKernel() ell_transform) * (CosineKernel() p_transform)
4044
end
4145

4246
function Base.getproperty(k::GaborKernel, v::Symbol)

src/kernels/transformedkernel.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
44
Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.
55
6-
It is preferred to create kernels with input transformations with [`∘`](@ref), or its
7-
alias [`compose`](@ref), instead of `TransformedKernel` directly since [`∘`](@ref)
8-
allows optimized implementations for specific kernels and transformations.
6+
It is preferred to create kernels with input transformations with `∘` or its alias
7+
`compose` instead of `TransformedKernel` directly since this allows optimized
8+
implementations for specific kernels and transformations.
99
1010
# Definition
1111

test/basekernels/gabor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
1111
@test k(v1, v2) k_manual atol = 1e-5
1212

13-
lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2)
14-
rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2)
13+
lhs_manual = (SqExponentialKernel() ScaleTransform(1 / k.ell))(v1, v2)
14+
rhs_manual = (CosineKernel() ScaleTransform(1 / k.p))(v1, v2)
1515
@test k(v1, v2) lhs_manual * rhs_manual atol = 1e-5
1616

1717
k = GaborKernel()

test/deprecations.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
@testset "deprecations.jl" begin
2+
p = rand()
3+
v = rand(3)
4+
M = rand(3, 3)
5+
kernel = SqExponentialKernel()
6+
7+
@test (@test_deprecated transform(kernel, LinearTransform(M))) ==
8+
kernel LinearTransform(M)
9+
@test (@test_deprecated transform(kernel ScaleTransform(p), ARDTransform(v))) ==
10+
kernel ARDTransform(v) ScaleTransform(p)
11+
@test (@test_deprecated transform(kernel, p)) == kernel ScaleTransform(p)
12+
@test (@test_deprecated transform(kernel, v)) == kernel ARDTransform(v)
13+
end

test/kernels/transformedkernel.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,11 @@
1010
k = SqExponentialKernel()
1111
kt = TransformedKernel(k, ScaleTransform(s))
1212
ktard = TransformedKernel(k, ARDTransform(v))
13-
@test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2)
14-
@test kt(v1, v2) == transform(k, s)(v1, v2)
1513
@test kt(v1, v2) == (k ScaleTransform(s))(v1, v2)
1614
@test kt(v1, v2) k(s * v1, s * v2) atol = 1e-5
17-
@test ktard(v1, v2) transform(k, ARDTransform(v))(v1, v2) atol = 1e-5
1815
@test ktard(v1, v2) == (k ARDTransform(v))(v1, v2)
19-
@test ktard(v1, v2) == transform(k, v)(v1, v2)
2016
@test ktard(v1, v2) == k(v .* v1, v .* v2)
21-
@test transform(kt, s2)(v1, v2) kt(s2 * v1, s2 * v2)
22-
@test KernelFunctions.kernel(kt) == k
17+
@test (kt s2)(v1, v2) kt(s2 * v1, s2 * v2)
2318
@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))
2419

2520
TestUtils.test_interface(k, Float64)
@@ -51,15 +46,12 @@
5146
P = rand(3, 2)
5247
c = Chain(Dense(3, 2))
5348

54-
test_params(transform(k, s), (k, [s]))
55-
test_params(transform(k, v), (k, v))
56-
test_params(transform(k, LinearTransform(P)), (k, P))
57-
test_params(transform(k, LinearTransform(P) ScaleTransform(s)), (k, [s], P))
58-
test_params(transform(k, FunctionTransform(c)), (k, c))
49+
test_params(k ScaleTransform(s), (k, [s]))
50+
test_params(k ARDTransform(v), (k, v))
51+
test_params(k LinearTransform(P), (k, P))
52+
test_params(k (LinearTransform(P) ScaleTransform(s)), (k, [s], P))
53+
test_params(k FunctionTransform(c), (k, c))
5954

6055
@test (k (LinearTransform(P') ScaleTransform(s)))(v1, v2) ==
6156
((k LinearTransform(P')) ScaleTransform(s))(v1, v2)
62-
test_params(k LinearTransform(P), (P, k))
63-
test_params(k LinearTransform(P) ScaleTransform(s), ([s], P, k))
64-
test_params(k FunctionTransform(c), (c, k))
6557
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ include("test_utils.jl")
148148
include("chainrules.jl")
149149
include("zygoterules.jl")
150150

151+
include("deprecations.jl")
152+
151153
@testset "doctests" begin
152154
DocMeta.setdocmeta!(
153155
KernelFunctions,

test/transform/ardtransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@
4242
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3)))
4343

4444
@test repr(t) == "ARD Transform (dims: $D)"
45-
test_ADs(x -> transform(SEKernel(), exp.(x)), randn(rng, 3))
45+
test_ADs(x -> SEKernel() ARDTransform(exp.(x)), randn(rng, 3))
4646
end

test/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
# Verify printing works as expected.
2424
@test repr(tp tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
2525
test_ADs(
26-
x -> transform(SEKernel(), ScaleTransform(exp(x[1])) ARDTransform(exp.(x[2:4]))),
26+
x -> SEKernel() (ScaleTransform(exp(x[1])) ARDTransform(exp.(x[2:4]))),
2727
randn(rng, 4),
2828
)
2929
end

test/transform/functiontransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@
2828

2929
@test repr(FunctionTransform(sin)) == "Function Transform: $(sin)"
3030
f(a, x) = sin.(a .* x)
31-
test_ADs(x -> transform(SEKernel(), FunctionTransform(y -> f(x, y))), randn(rng, 3))
31+
test_ADs(x -> SEKernel() FunctionTransform(y -> f(x, y)), randn(rng, 3))
3232
end

test/transform/lineartransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@
4242
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, Din + 1, Dout)))
4343

4444
@test repr(t) == "Linear transform (size(A) = ($Dout, $Din))"
45-
test_ADs(x -> transform(SEKernel(), LinearTransform(x)), randn(rng, 3, 3))
45+
test_ADs(x -> SEKernel() LinearTransform(x), randn(rng, 3, 3))
4646
end

test/transform/periodic_transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
x = collect(range(0.0, 3.0 / f; length=1_000))
66

77
# Construct in the usual way.
8-
k_eq_periodic = transform(PeriodicKernel(; r=[sqrt(0.25)]), f)
8+
k_eq_periodic = PeriodicKernel(; r=[sqrt(0.25)]) ScaleTransform(f)
99

1010
# Construct using the peridic transform.
11-
k_eq_transform = transform(SqExponentialKernel(), PeriodicTransform(f))
11+
k_eq_transform = SqExponentialKernel() PeriodicTransform(f)
1212

1313
@test kernelmatrix(k_eq_periodic, x) kernelmatrix(k_eq_transform, x)
1414
# TODO - add interface_tests once #159 is merged.

test/transform/scaletransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@
1919
@test t.s == [s2]
2020
@test isequal(ScaleTransform(s), ScaleTransform(s))
2121
@test repr(t) == "Scale Transform (s = $(s2))"
22-
test_ADs(x -> transform(SEKernel(), exp(x[1])), randn(rng, 1))
22+
test_ADs(x -> SEKernel() ScaleTransform(exp(x[1])), randn(rng, 1))
2323
end

test/transform/selecttransform.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
@test repr(t) == "Select Transform (dims: $(select2))"
4545
@test repr(ts) == "Select Transform (dims: $(select_symbols2))"
4646

47-
test_ADs(() -> transform(SEKernel(), SelectTransform([1, 2])))
47+
test_ADs(() -> SEKernel() SelectTransform([1, 2]))
4848

4949
X = randn(rng, (4, 3))
5050
A = AxisArray(X; row=[:a, :b, :c, :d], col=[:x, :y, :z])
@@ -53,10 +53,10 @@
5353
Z = randn(rng, (2, 3))
5454
C = AxisArray(Z; row=[:e, :f], col=[:x, :y, :z])
5555

56-
tx_row = transform(SEKernel(), SelectTransform([1, 2, 4]))
57-
ta_row = transform(SEKernel(), SelectTransform([:a, :b, :d]))
58-
tx_col = transform(SEKernel(), SelectTransform([1, 3]))
59-
ta_col = transform(SEKernel(), SelectTransform([:x, :z]))
56+
tx_row = SEKernel() SelectTransform([1, 2, 4])
57+
ta_row = SEKernel() SelectTransform([:a, :b, :d])
58+
tx_col = SEKernel() SelectTransform([1, 3])
59+
ta_col = SEKernel() SelectTransform([:x, :z])
6060

6161
@test kernelmatrix(tx_row, X; obsdim=2) kernelmatrix(ta_row, A; obsdim=2)
6262
@test kernelmatrix(tx_col, X; obsdim=1) kernelmatrix(ta_col, A; obsdim=1)

test/transform/transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
@test IdentityTransform()(x) == x
99
@test map(IdentityTransform(), x) == x
1010
end
11-
test_ADs(() -> transform(SEKernel(), IdentityTransform()))
11+
test_ADs(() -> SEKernel() IdentityTransform())
1212
end

0 commit comments

Comments
 (0)