Skip to content

Commit 6aea76e

Browse files
authored
Merge pull request #141 from JuliaGaussianProcesses/fix_implicit_gradients
Fixing implicit gradients
2 parents fec2318 + 76d9b6a commit 6aea76e

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

src/kernels/transformedkernel.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ end
6363
# Kernel matrix operations
6464

6565
function kerneldiagmatrix!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
66-
return kerneldiagmatrix!(K, κ.kernel, map.transform, x))
66+
return kerneldiagmatrix!(K, κ.kernel, _map.transform, x))
6767
end
6868

6969
function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
70-
return kernelmatrix!(K, kernel(κ), map.transform, x))
70+
return kernelmatrix!(K, kernel(κ), _map.transform, x))
7171
end
7272

7373
function kernelmatrix!(
@@ -76,17 +76,17 @@ function kernelmatrix!(
7676
x::AbstractVector,
7777
y::AbstractVector,
7878
)
79-
return kernelmatrix!(K, kernel(κ), map.transform, x), map.transform, y))
79+
return kernelmatrix!(K, kernel(κ), _map.transform, x), _map.transform, y))
8080
end
8181

8282
function kerneldiagmatrix::TransformedKernel, x::AbstractVector)
83-
return kerneldiagmatrix.kernel, map.transform, x))
83+
return kerneldiagmatrix.kernel, _map.transform, x))
8484
end
8585

8686
function kernelmatrix::TransformedKernel, x::AbstractVector)
87-
return kernelmatrix(kernel(κ), map.transform, x))
87+
return kernelmatrix(kernel(κ), _map.transform, x))
8888
end
8989

9090
function kernelmatrix::TransformedKernel, x::AbstractVector, y::AbstractVector)
91-
return kernelmatrix(kernel(κ), map.transform, x), map.transform, y))
91+
return kernelmatrix(kernel(κ), _map.transform, x), _map.transform, y))
9292
end

src/transform/scaletransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717

1818
set!(t::ScaleTransform::Real) = t.s .= [ρ]
1919

20-
(t::ScaleTransform)(x) = first(t.s) .* x
20+
(t::ScaleTransform)(x) = first(t.s) * x
2121

2222
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
2323
_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)

test/kernels/transformedkernel.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,27 @@
5252
end
5353
end
5454
test_ADs(x->transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff])
55+
# Test implicit gradients
56+
@testset "Implicit gradients" begin
57+
k = transform(SqExponentialKernel(), 2.0)
58+
ps = Flux.params(k)
59+
X = rand(10, 1); x = vec(X)
60+
A = rand(10, 10)
61+
# Implicit
62+
g1 = Flux.gradient(ps) do
63+
tr(kernelmatrix(k, X, obsdim = 1) * A)
64+
end
65+
# Explicit
66+
g2 = Flux.gradient(k) do k
67+
tr(kernelmatrix(k, X, obsdim = 1) * A)
68+
end
69+
70+
# Implicit for a vector
71+
g3 = Flux.gradient(ps) do
72+
tr(kernelmatrix(k, x) * A)
73+
end
74+
@test g1[first(ps)] first(g2).transform.s
75+
@test g1[first(ps)] g3[first(ps)]
76+
end
77+
5578
end

0 commit comments

Comments
 (0)