Fix gradient issues with kernelmatrix_diag and use ChainRulesCore #208
Merged
50 commits
e525614
Use broadcasting instead of map for kerneldiagmatrix
theogf e56492a
Removed method for transformedkernel
theogf 35a6306
Restored functions and applied suggestions
theogf 25e5efd
Added tests for diagmatrix
theogf 2f85ebc
Put changes to the right file and removed utils_AD.jl
theogf cae225f
Apply suggestions from code review
theogf 3f16f07
Added colwise and fixed kerneldiagmatrix
theogf 8c0d0a2
Added colwise for RowVecs and ColVecs
theogf 13a10fd
Removed definition relying on Distances.colwise!
theogf 78a2078
Merge branch 'master' into fix_diagmat
theogf 5ca94e7
Readapt to kernelmatrix_diag
theogf 2c60abd
Fixes for Zygote
theogf 9214211
Remove type piracy
theogf 87edbc8
Adding some adjoints (not everything fixed yet)
theogf f65556b
Fixed adjoint for polynomials
theogf 48e2dcb
Add ChainRulesCore for defining rrule
theogf 6cc803d
Replace broadcast by map
theogf 0e30941
Missing return for style
theogf 61869b1
Fixing ZygoteRules
theogf 06bd4f0
Renamed zygote_adjoints to chainrules
theogf 8e1e516
Apply formatting suggestions
theogf aaa16de
Added forward rule for Euclidean distance
theogf 52b1ae5
Corrected rules for Row/ColVecs constructors
theogf 4067a42
Added ZygoteRules back for the "map hack"
theogf 641ebee
Corrected the rrules
theogf 13d1e39
Type stable frule
theogf 4675c2f
Corrected tests
theogf 0b97c1a
Adapted the use of Distances.jl
theogf ad9838e
Added methods to make nn work
theogf 650dc08
Missing kernelmatrix_diag
theogf 1703db1
Formatting suggestions
theogf e2cd167
Added methods for FBM
theogf 01ffac0
Last fix on Delta
theogf 9bfb6eb
Potential fix for Euclidean
theogf f3fa4bc
Missing Distances.
theogf a0c2a64
Wrong file naming
theogf ff5a66b
Correct formatting
theogf 8157b4c
Better error message
theogf e6bfdb1
Moar formatting
theogf db5e7b8
Applied suggestions
theogf a44a762
Fixed the dims issue with pairwise
theogf 72889dd
Fixed formatting
theogf 25549c1
Missing @thunk
theogf bbe5c7c
Putting back Composite to Any
theogf e08dbf4
add @thunk for -delta a
theogf 48bd681
Update src/chainrules.jl
theogf 3298d34
Update KernelFunctions.jl
theogf 0b99771
Apply suggestions from code review
theogf c26edf3
Update Project.toml
theogf 647862a
Merge branch 'master' into fix_diagmat
theogf
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
## Forward Rules | ||
|
||
# Note that this is type piracy as the derivative should be NaN for x == y. | ||
function ChainRulesCore.frule( | ||
(_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector | ||
) | ||
Δ = x - y | ||
D = sqrt(sum(abs2, Δ)) | ||
if !iszero(D) | ||
Δ ./= D | ||
end | ||
return D, dot(Δ, Δx) - dot(Δ, Δy) | ||
end | ||
|
||
## Reverse Rules Delta | ||
|
||
function ChainRulesCore.rrule(dist::Delta, x::AbstractVector, y::AbstractVector) | ||
d = dist(x, y) | ||
function evaluate_pullback(::Any) | ||
return NO_FIELDS, Zero(), Zero() | ||
end | ||
return d, evaluate_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 | ||
) | ||
P = Distances.pairwise(d, X, Y; dims=dims) | ||
function pairwise_pullback(::AbstractMatrix) | ||
return NO_FIELDS, NO_FIELDS, Zero(), Zero() | ||
end | ||
return P, pairwise_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2 | ||
) | ||
P = Distances.pairwise(d, X; dims=dims) | ||
function pairwise_pullback(::AbstractMatrix) | ||
return NO_FIELDS, NO_FIELDS, Zero() | ||
end | ||
return P, pairwise_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix | ||
) | ||
C = Distances.colwise(d, X, Y) | ||
function colwise_pullback(::AbstractVector) | ||
return NO_FIELDS, NO_FIELDS, Zero(), Zero() | ||
end | ||
return C, colwise_pullback | ||
end | ||
|
||
## Reverse Rules DotProduct | ||
|
||
function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector) | ||
d = dist(x, y) | ||
function evaluate_pullback(Δ::Any) | ||
return NO_FIELDS, Δ .* y, Δ .* x | ||
end | ||
return d, evaluate_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), | ||
d::DotProduct, | ||
X::AbstractMatrix, | ||
Y::AbstractMatrix; | ||
dims=2, | ||
) | ||
P = Distances.pairwise(d, X, Y; dims=dims) | ||
function pairwise_pullback_cols(Δ::AbstractMatrix) | ||
if dims == 1 | ||
return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X | ||
else | ||
return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ | ||
end | ||
end | ||
return P, pairwise_pullback_cols | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2 | ||
) | ||
P = Distances.pairwise(d, X; dims=dims) | ||
function pairwise_pullback_cols(Δ::AbstractMatrix) | ||
if dims == 1 | ||
return NO_FIELDS, NO_FIELDS, 2 * Δ * X | ||
else | ||
return NO_FIELDS, NO_FIELDS, 2 * X * Δ | ||
end | ||
end | ||
return P, pairwise_pullback_cols | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix | ||
) | ||
C = Distances.colwise(d, X, Y) | ||
function colwise_pullback(Δ::AbstractVector) | ||
return NO_FIELDS, NO_FIELDS, Δ' .* Y, Δ' .* X | ||
end | ||
return C, colwise_pullback | ||
end | ||
|
||
## Reverse Rules Sinus | ||
|
||
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) | ||
d = x - y | ||
sind = sinpi.(d) | ||
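# Divide by r .^ 2 so that the value is consistent with gradx below. | ||
abs2_sind_r = abs2.(sind) ./ s.r .^ 2 | ||
val = sum(abs2_sind_r) | ||
gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) | ||
function evaluate_pullback(Δ::Any) | ||
# d/dr of sin^2 / r^2 is -2 * sin^2 / r^3. | ||
return (r=-2Δ .* abs2_sind_r ./ s.r,), Δ * gradx, -Δ * gradx | ||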
end | ||
return val, evaluate_pullback | ||
end | ||
|
||
## Reverse Rules SqMahalanobis | ||
|
||
function ChainRulesCore.rrule( | ||
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector | ||
) | ||
d = dist(a, b) | ||
function SqMahalanobis_pullback(Δ::Real) | ||
B_Bᵀ = dist.qmat + transpose(dist.qmat) | ||
a_b = a - b | ||
δa = @thunk((B_Bᵀ * a_b) * Δ) | ||
return (qmat=(a_b * a_b') * Δ,), δa, -δa | ||
end | ||
return d, SqMahalanobis_pullback | ||
end | ||
|
||
## Reverse Rules for matrix wrappers | ||
|
||
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) | ||
ColVecs_pullback(Δ::Composite{<:ColVecs}) = (NO_FIELDS, Δ.X) | ||
function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) | ||
return error( | ||
"Pullback on AbstractVector{<:AbstractVector}.\n" * | ||
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * | ||
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", | ||
) | ||
end | ||
return ColVecs(X), ColVecs_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) | ||
RowVecs_pullback(Δ::Composite{<:RowVecs}) = (NO_FIELDS, Δ.X) | ||
function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) | ||
return error( | ||
"Pullback on AbstractVector{<:AbstractVector}.\n" * | ||
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * | ||
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", | ||
) | ||
end | ||
return RowVecs(X), RowVecs_pullback | ||
end | ||
|
||
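The pullback formulas in this diff can be sanity-checked numerically without any AD package. The following is a minimal sketch, not part of the PR: it uses only `LinearAlgebra`, and `fd_gradient`, `loss`, and the test data are hypothetical names introduced for illustration. It assumes that `pairwise(::DotProduct, X, Y; dims=2)` equals `X' * Y` (columns as observations), and compares the two-argument `DotProduct` pullback and the Euclidean `frule` expression against central finite differences.

```julia
using LinearAlgebra

# Central finite-difference gradient of a scalar function f at the array A.
function fd_gradient(f, A; h=1e-6)
    G = similar(A, Float64)
    for i in eachindex(A)
        Ap = copy(A); Ap[i] += h
        Am = copy(A); Am[i] -= h
        G[i] = (f(Ap) - f(Am)) / (2h)
    end
    return G
end

# Deterministic test data; columns are observations, matching dims=2.
X = reshape(collect(1.0:6.0), 3, 2) ./ 7    # 3 features, 2 points
Y = reshape(collect(2.0:10.0), 3, 3) ./ 11  # 3 features, 3 points
Δ = reshape(collect(1.0:6.0), 2, 3) ./ 5    # cotangent, same size as X' * Y

# Assumption: for dims=2 the pairwise dot-product matrix is P = X' * Y, so the
# scalar loss sum(Δ .* P) has gradient Y * Δ' w.r.t. X and X * Δ w.r.t. Y.
loss(A, B) = sum(Δ .* (A' * B))
gX_fd = fd_gradient(A -> loss(A, Y), X)
gY_fd = fd_gradient(B -> loss(X, B), Y)
dotproduct_ok = isapprox(gX_fd, Y * Δ'; atol=1e-6) && isapprox(gY_fd, X * Δ; atol=1e-6)

# Euclidean distance: the directional derivative along (Δx, Δy) is
# dot(u, Δx) - dot(u, Δy) with u = (x - y) / norm(x - y), valid for x != y.
x, y = [1.0, 2.0, 3.0], [0.5, -1.0, 2.0]
Δx, Δy = [0.1, 0.2, -0.3], [0.4, 0.0, 0.25]
u = (x - y) / norm(x - y)
h = 1e-6
fd = (norm((x + h * Δx) - (y + h * Δy)) - norm((x - h * Δx) - (y - h * Δy))) / (2h)
euclidean_ok = isapprox(fd, dot(u, Δx) - dot(u, Δy); atol=1e-6)
```

Both flags should come out `true` if the formulas hold. The same helper could be pointed at the single-argument `pairwise` rule, where the `2 * X * Δ` expression is expected to match only when `Δ` is symmetric.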