Skip to content

Commit 5abdada

Browse files
committed
Added set! and corrections
1 parent 9581045 commit 5abdada

File tree

6 files changed

+54
-9
lines changed

6 files changed

+54
-9
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1010
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1111

1212
[compat]
13-
julia = "1.0"
13+
Distances = "0.8.2"
1414
PDMats = "0.9.9"
1515
SpecialFunctions = "0.7.2"
16-
Distances = "0.8.2"
16+
StatsFuns = "0.8"
17+
julia = "1.0"
1718

1819
[extras]
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/generic.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,19 @@ for kernel in [:ExponentialKernel,:SqExponentialKernel,:Matern32Kernel,:Matern52
2929
$kernel(t::Tr) where {Tr<:Transform} = $kernel{eltype(t),Tr}(t)
3030
end
3131
end
32+
33+
function set!()
34+
35+
end
36+
37+
function set!(k::Kernel{T,ScaleTransform{Base.RefValue{<:Tρ}}}::Tρ) where {T,Tρ<:Real}
38+
set!(k.transform,ρ)
39+
end
40+
41+
function set!(k::Kernel{T,ScaleTransform{<:AbstractVector{<:Tρ}}}::AbstractVector{<:Tρ}) where {T,Tρ<:Real}
42+
set!(k.transform,ρ)
43+
end
44+
45+
function set!(k::Kernel{T,LowRankTransform{<:AbstractMatrix{<:Tm}}},m::AbstractMatrix{<:Tm}) where {T,Tm<:Real}
46+
set!(k.transform,m)
47+
end

src/kernels/exponential.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ struct SqExponentialKernel{T,Tr} <: Kernel{T,Tr}
1818
end
1919
end
2020

21-
@inline kappa::SqExponentialKernel, d²::Real) where {T} = exp(-d²)
21+
@inline kappa::SqExponentialKernel, d²::Real) = exp(-d²)
2222

2323
### Aliases
2424
const RBFKernel = SqExponentialKernel
@@ -42,7 +42,7 @@ struct ExponentialKernel{T,Tr} <: Kernel{T,Tr}
4242
end
4343
end
4444

45-
@inline kappa::ExponentialKernel, d::Real) where {T} = exp(-d)
45+
@inline kappa::ExponentialKernel, d::Real) = exp(-d)
4646

4747
### Aliases
4848
const LaplacianKernel = ExponentialKernel
@@ -80,4 +80,4 @@ function GammaExponentialKernel(t::Tr,gamma::T₁=2.0) where {Tr<:Transform,T₁
8080
GammaExponentialKernel{eltype(Tr),Tr,T₁}(t,gamma)
8181
end
8282

83-
@inline kappa::GammaExponentialKernel, d²::Real) where {T} = exp(-^κ.γ)
83+
@inline kappa::GammaExponentialKernel, d²::Real) = exp(-^κ.γ)

src/matrix/kernelkroeneckermat.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
function kernelkronmat(
2+
κ::Kernel,
3+
X::AbstractVector,
4+
dims::Int
5+
)
6+
@assert iskroncompatible(κ) "The kernel chosed is not compatible for kroenecker matrices"
7+
K = kernelmatrix(κ,reshape(X,:,1),obsdim=1)
8+
9+
end
10+
11+
12+
function iskroncompatible::Kernel)
13+
14+
end

src/transform/lowranktransform.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ struct LowRankTransform{T<:AbstractMatrix{<:Real}} <: Transform
1111
proj::T
1212
end
1313

14+
function set!(t::LowRankTransform{<:AbstractMatrix{T}},M::AbstractMatrix{T}) where {T<:Real}
15+
@assert size(t) == size(M) "Size of the given matrix $(size(M)) and the projection matrix $(size(t)) are not the same"
16+
t.proj .= M
17+
end
18+
1419
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
1520
Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test
1621

src/transform/scaletransform.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
Multiply every element of the matrix by `l` for a scalar
1010
Multiply every vector of observation by `v` element-wise for a vector
1111
"""
12-
struct ScaleTransform{T<:Union{Real,AbstractVector{<:Real}}} <: Transform
12+
struct ScaleTransform{T<:Union{Base.RefValue{<:Real},AbstractVector{<:Real}}} <: Transform
1313
s::T
1414
end
1515

1616
function ScaleTransform(s::T=1.0) where {T<:Real}
1717
@check_args(ScaleTransform, s, s > zero(T), "s > 0")
18-
ScaleTransform{T}(s)
18+
ScaleTransform{T}(Ref(s))
1919
end
2020

2121
function ScaleTransform(s::T,dims::Integer) where {T<:Real} # TODO Add test
@@ -28,7 +28,16 @@ function ScaleTransform(s::A) where {A<:AbstractVector{<:Real}}
2828
ScaleTransform{A}(s)
2929
end
3030

31-
dim(str::ScaleTransform{<:Real}) = 1 #TODO Add test
31+
function set!(t::ScaleTransform{Base.RefValue{T}}::T) where {T<:Real}
32+
t.s[] = ρ
33+
end
34+
35+
function set!(t::ScaleTransform{AbstractVector{T}}::AbstractVector{T}) where {T<:Real}
36+
@assert length(ρ) == dim(t) "Trying to set a vector of size $(length(ρ)) to ScaleTransform of dimension $(dim(t))"
37+
t.s .= ρ
38+
end
39+
40+
dim(str::ScaleTransform{Base.RefValue{<:Real}}) = 1 #TODO Add test
3241
dim(str::ScaleTransform{<:AbstractVector{<:Real}}) = length(str.s)
3342

3443
function transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int)
@@ -40,4 +49,4 @@ end
4049
transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real},obsdim::Int=defaultobs) = t.s .* x
4150
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 1 ? t.s'.*X : t.s .* X
4251

43-
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat,obsdim::Int=defaultobs) = t.s .* x
52+
transform(t::ScaleTransform{Base.RefValue{<:Real}},x::AbstractVecOrMat,obsdim::Int=defaultobs) = t.s[] .* x

0 commit comments

Comments
 (0)