Skip to content

Commit 7143f87

Browse files
devmotiontheogfgithub-actions[bot]
authored
Add median_heuristic_transform (#245)
* Add `median_heuristic_transform` * Move `median_heuristic_transform` to convenience fcns in the docs * Only compute pairwise distances between different elements * Update src/transform/scaletransform.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix tests * Update test/transform/scaletransform.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml Co-authored-by: Théo Galy-Fajou <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 99b53c6 commit 7143f87

File tree

8 files changed

+61
-1
lines changed

8 files changed

+61
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.34"
3+
version = "0.10.35"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -15,6 +15,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1515
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1616
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1717
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
18+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1819
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1920
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
2021
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

docs/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
[deps]
2+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
23
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
34
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
45
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
56
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
7+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
68

79
[compat]
10+
Distances = "0.10"
811
Documenter = "0.27"
912
KernelFunctions = "0.10"
1013
Kronecker = "0.4, 0.5"

docs/src/transform.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,5 @@ PeriodicTransform
4141

4242
```@docs
4343
with_lengthscale
44+
median_heuristic_transform
4445
```

src/KernelFunctions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ export Transform,
3030
PeriodicTransform
3131
export with_lengthscale
3232

33+
export median_heuristic_transform
34+
3335
export NystromFact, nystrom
3436

3537
export gaborkernel
@@ -63,6 +65,8 @@ using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield
6365
# Hack to work around Zygote type inference problems.
6466
const Distances_pairwise = Distances.pairwise
6567

68+
using Statistics: median!
69+
6670
abstract type Kernel end
6771
abstract type SimpleKernel <: Kernel end
6872

src/transform/scaletransform.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,39 @@ _map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)
3333
Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))
3434

3535
Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", only(t.s), ")")
36+
37+
# Helpers
38+
39+
"""
40+
median_heuristic_transform(distance, x::AbstractVector)
41+
42+
Create a [`ScaleTransform`](@ref) that divides the input elementwise by the median
43+
`distance` of the data points in `x`.
44+
45+
The `distance` has to support pairwise evaluation with `KernelFunctions.pairwise`. All
46+
`PreMetric`s of the package [Distances.jl](https://github.com/JuliaStats/Distances.jl) such
47+
as `Euclidean` satisfy this requirement automatically.
48+
49+
# Examples
50+
51+
```jldoctest
52+
julia> using Distances, Statistics
53+
54+
julia> x = ColVecs(rand(100, 10));
55+
56+
julia> t = median_heuristic_transform(Euclidean(), x);
57+
58+
julia> y = map(t, x);
59+
60+
julia> median(euclidean(y[i], y[j]) for i in 1:10, j in 1:10 if i != j) ≈ 1
61+
true
62+
```
63+
"""
64+
function median_heuristic_transform(f, x::AbstractVector)
65+
# Compute pairwise distances between **different** elements
66+
n = length(x)
67+
distances = vec(pairwise(f, x))
68+
deleteat!(distances, 1:(n + 1):(n^2))
69+
70+
return ScaleTransform(inv(median!(distances)))
71+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
16+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1819

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using LogExpFunctions
99
using PDMats
1010
using Random
1111
using SpecialFunctions
12+
using Statistics
1213
using Test
1314
using Zygote: Zygote
1415
using ForwardDiff: ForwardDiff

test/transform/scaletransform.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,17 @@
2020
@test isequal(ScaleTransform(s), ScaleTransform(s))
2121
@test repr(t) == "Scale Transform (s = $(s2))"
2222
test_ADs(x -> SEKernel() ScaleTransform(exp(x[1])), randn(rng, 1))
23+
24+
@testset "median heuristic" begin
25+
for x in (x, XV, XC, XR), dist in (Euclidean(), Cityblock())
26+
n = length(x)
27+
t = median_heuristic_transform(dist, x)
28+
@test t isa ScaleTransform
29+
@test first(t.s)
30+
inv(median(dist(x[i], x[j]) for i in 1:n, j in 1:n if i != j))
31+
32+
y = map(t, x)
33+
@test median(dist(y[i], y[j]) for i in 1:n, j in 1:n if i != j) 1
34+
end
35+
end
2336
end

0 commit comments

Comments
 (0)