Skip to content

Commit 55f4909

Browse files
authored
Patch AD performance bug (#272)
* Replace Base.Fix1 with regular closure * Bump patch
1 parent 13985cc commit 55f4909

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.9.0"
3+
version = "0.9.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/matrix/kernelmatrix.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,32 +80,32 @@ kernelmatrix_diag(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x,
8080
function kernelmatrix!(K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector)
8181
validate_inplace_dims(K, x)
8282
pairwise!(K, metric(κ), x)
83-
return map!(Base.Fix1(kappa, κ), K, K)
83+
return map!(x -> kappa(κ, x), K, K)
8484
end
8585

8686
function kernelmatrix!(
8787
K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector, y::AbstractVector
8888
)
8989
validate_inplace_dims(K, x, y)
9090
pairwise!(K, metric(κ), x, y)
91-
return map!(Base.Fix1(kappa, κ), K, K)
91+
return map!(x -> kappa(κ, x), K, K)
9292
end
9393

9494
function kernelmatrix::SimpleKernel, x::AbstractVector)
95-
return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x))
95+
return map(x -> kappa(κ, x), pairwise(metric(κ), x))
9696
end
9797

9898
function kernelmatrix::SimpleKernel, x::AbstractVector, y::AbstractVector)
9999
validate_inputs(x, y)
100-
return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x, y))
100+
return map(x -> kappa(κ, x), pairwise(metric(κ), x, y))
101101
end
102102

103103
function kernelmatrix_diag::SimpleKernel, x::AbstractVector)
104-
return map(Base.Fix1(kappa, κ), colwise(metric(κ), x))
104+
return map(x -> kappa(κ, x), colwise(metric(κ), x))
105105
end
106106

107107
function kernelmatrix_diag::SimpleKernel, x::AbstractVector, y::AbstractVector)
108-
return map(Base.Fix1(kappa, κ), colwise(metric(κ), x, y))
108+
return map(x -> kappa(κ, x), colwise(metric(κ), x, y))
109109
end
110110

111111
#

0 commit comments

Comments
 (0)