Skip to content

Commit adf55bf

Browse files
authored
Merge pull request #106 from devmotion/lineartransform
Use LinearTransform instead of LowRankTransform
2 parents 9723a5c + e5c48d9 commit adf55bf

12 files changed

+56
-54
lines changed

docs/create_kernel_plots.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ p = heatmap(K2,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(
2323
savefig(joinpath(@__DIR__,"src","assets","heatmap_matern.png"))
2424

2525

26-
k = transform(PolynomialKernel(c=0.0,d=2.0),LowRankTransform(randn(3,1)))
26+
k = transform(PolynomialKernel(c=0.0,d=2.0), LinearTransform(randn(3,1)))
2727
K3 = kernelmatrix(k,xrange,obsdim=1)
2828
p = heatmap(K3,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
2929
savefig(joinpath(@__DIR__,"src","assets","heatmap_poly.png"))

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ export KernelSum, KernelProduct
2323
export TransformedKernel, ScaledKernel
2424
export TensorProduct
2525

26-
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
26+
export Transform, SelectTransform, ChainTransform, ScaleTransform, LinearTransform,
27+
ARDTransform, IdentityTransform, FunctionTransform
2728

2829
export NystromFact, nystrom
2930

src/trainable.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ trainable(t::ChainTransform) = t.transforms
4242

4343
trainable(t::FunctionTransform) = (t.f,)
4444

45-
trainable(t::LowRankTransform) = (t.proj,)
45+
trainable(t::LinearTransform) = (t.A,)
4646

4747
trainable(t::ScaleTransform) = (t.s,)

src/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Chain a series of transform, here `t1` will be called first
55
```
66
t1 = ScaleTransform()
7-
t2 = LowRankTransform(rand(3,4))
7+
t2 = LinearTransform(rand(3,4))
88
ct = ChainTransform([t1,t2]) #t1 will be called first
99
ct == t2 ∘ t1
1010
```

src/transform/lineartransform.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
2+
LinearTransform(A::AbstractMatrix)
3+
4+
Apply the linear transformation realised by the matrix `A`.
5+
6+
The second dimension of `A` must match the number of features of the target.
7+
8+
# Examples
9+
10+
```julia-repl
11+
julia> A = rand(10, 5)
12+
13+
julia> tr = LinearTransform(A)
14+
```
15+
"""
16+
struct LinearTransform{T<:AbstractMatrix{<:Real}} <: Transform
17+
A::T
18+
end
19+
20+
function set!(t::LinearTransform{<:AbstractMatrix{T}}, A::AbstractMatrix{T}) where {T<:Real}
21+
size(t.A) == size(A) ||
22+
error("size of the given matrix ", size(A), " and of the transformation matrix ",
23+
size(t.A), " are not the same")
24+
t.A .= A
25+
end
26+
27+
(t::LinearTransform)(x::Real) = vec(t.A * x)
28+
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x
29+
30+
Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
31+
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
32+
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
33+
34+
function Base.show(io::IO, t::LinearTransform)
35+
print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")
36+
end

src/transform/lowranktransform.jl

Lines changed: 0 additions & 33 deletions
This file was deleted.

src/transform/transform.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
export Transform, IdentityTransform, ScaleTransform, ARDTransform, LowRankTransform, FunctionTransform, ChainTransform
2-
31
include("scaletransform.jl")
42
include("ardtransform.jl")
5-
include("lowranktransform.jl")
3+
include("lineartransform.jl")
64
include("functiontransform.jl")
75
include("selecttransform.jl")
86
include("chaintransform.jl")

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ using KernelFunctions: metric, kappa
5656
include(joinpath("transform", "transform.jl"))
5757
include(joinpath("transform", "scaletransform.jl"))
5858
include(joinpath("transform", "ardtransform.jl"))
59-
include(joinpath("transform", "lowranktransform.jl"))
59+
include(joinpath("transform", "lineartransform.jl"))
6060
include(joinpath("transform", "functiontransform.jl"))
6161
include(joinpath("transform", "selecttransform.jl"))
6262
include(joinpath("transform", "chaintransform.jl"))

test/test_AD.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ end
4545
transform_AD(Val(Symbol($AD)),ScaleTransform(l),A)
4646
# ARD Transform
4747
transform_AD(Val(Symbol($AD)),ARDTransform(vl),A)
48-
# LowRankTransform
49-
transform_AD(Val(Symbol($AD)),LowRankTransform(rand(2,10)),A)
48+
# Linear transform
49+
transform_AD(Val(Symbol($AD)), LinearTransform(rand(2,10)),A)
5050
# Chain Transform
51-
# transform_AD(Val(Symbol($AD)),LowRankTransform,A)
51+
# transform_AD(Val(Symbol($AD)), LinearTransform, A)
5252
end
5353
end
5454
end

test/trainable.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@
4444
@test all(params(k) .== params(v, kc))
4545

4646
P = rand(3, 2)
47-
k = transform(km,LowRankTransform(P))
47+
k = transform(km, LinearTransform(P))
4848
@test all(params(k) .== params(P, km))
4949

50-
k = transform(km, LowRankTransform(P) ScaleTransform(s))
50+
k = transform(km, LinearTransform(P) ScaleTransform(s))
5151
@test all(params(k) .== params([s], P, km))
5252

5353
c = Chain(Dense(3, 2))

test/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
rng = MersenneTwister(123546)
33

44
P = rand(rng, 3, 2)
5-
tp = LowRankTransform(P)
5+
tp = LinearTransform(P)
66

77
f(x) = sin.(x)
88
tf = FunctionTransform(f)

test/transform/lowranktransform.jl renamed to test/transform/lineartransform.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
@testset "lowranktransform" begin
1+
@testset "lineartransform" begin
22
rng = MersenneTwister(123546)
33

44
@testset "Real inputs" begin
55
P = randn(rng, 3, 1)
6-
t = LowRankTransform(P)
6+
t = LinearTransform(P)
77

88
x = randn(rng, 4)
99
x′ = map(t, x)
@@ -16,7 +16,7 @@
1616
Din = 3
1717
Dout = 4
1818
P = randn(rng, Dout, Din)
19-
t = LowRankTransform(P)
19+
t = LinearTransform(P)
2020

2121
x_cols = ColVecs(randn(rng, Din, 8))
2222
x_rows = RowVecs(randn(rng, 9, Din))
@@ -31,14 +31,14 @@
3131
Din = 2
3232
Dout = 5
3333
P = randn(rng, Dout, Din)
34-
t = LowRankTransform(P)
34+
t = LinearTransform(P)
3535

3636
P2 = randn(rng, Dout, Din)
3737
KernelFunctions.set!(t, P2)
38-
@test t.proj == P2
39-
@test_throws AssertionError KernelFunctions.set!(t, rand(rng, Din + 1, Dout))
38+
@test t.A == P2
39+
@test_throws ErrorException KernelFunctions.set!(t, rand(rng, Din + 1, Dout))
4040

4141
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, Din + 1, Dout)))
4242

43-
@test repr(t) == "Low Rank Transform (size(P) = ($Dout, $Din))"
43+
@test repr(t) == "Linear transform (size(A) = ($Dout, $Din))"
4444
end

0 commit comments

Comments
 (0)