Fix AD issue with mixed input for TransformedKernel #160

Merged
sharanry merged 6 commits into master from sharan/fix-AD-bug on Aug 31, 2020

Conversation

sharanry
Contributor

Fix #151

@sharanry sharanry requested review from willtebbutt and theogf August 31, 2020 07:45
@devmotion
Member

I don't think this fixes the actual problem. IMO we shouldn't have to handle and convert ColVecs and RowVecs inputs here. Could it be that it's just missing a vec_of_vecs definition for ColVecs and RowVecs, or an adjoint for _map?

@sharanry
Contributor Author

> I don't think this fixes the actual problem. IMO we shouldn't have to handle and convert ColVecs and RowVecs inputs here. Could it be that it's just missing a vec_of_vecs definition for ColVecs and RowVecs, or an adjoint for _map?

The transform map is applied to each input individually, IIRC, so I don't think it should cause a problem. The problem arises when executing the pullback function using the common result as a reference. Maybe defining an adjoint for kernelmatrix of TransformedKernel would help? Isn't that where the inputs are combined to form a single output?

@devmotion
Member

This shows that the problem is not caused by kernelmatrix:

julia> Zygote.gradient() do
           sum(sum(KernelFunctions._map(k.transform, X)))
       end
()

julia> Zygote.gradient() do
           sum(sum(KernelFunctions._map(k.transform, Y)))
       end
ERROR: In slow method
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::KernelFunctions.var"#back#173")(::FillArrays.Fill{FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at /home/david/.julia/dev/KernelFunctions/src/zygote_adjoints.jl:66
 [3] (::KernelFunctions.var"#159#back#174"{KernelFunctions.var"#back#173"})(::FillArrays.Fill{FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at /home/david/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] _map at /home/david/.julia/dev/KernelFunctions/src/transform/scaletransform.jl:25 [inlined]
 [5] (::typeof(∂(_map)))(::FillArrays.Fill{FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at /home/david/.julia/packages/Zygote/rqvFi/src/compiler/interface2.jl:0
 [6] #17 at ./REPL[25]:2 [inlined]
 [7] (::typeof(∂(#17)))(::Float64) at /home/david/.julia/packages/Zygote/rqvFi/src/compiler/interface2.jl:0
 [8] (::Zygote.var"#41#42"{typeof(∂(#17))})(::Float64) at /home/david/.julia/packages/Zygote/rqvFi/src/compiler/interface.jl:45
 [9] gradient(::Function) at /home/david/.julia/packages/Zygote/rqvFi/src/compiler/interface.jl:54
 [10] top-level scope at REPL[25]:1

@devmotion
Member

Nor is it caused by _map:

julia> Zygote.gradient(rand(3, 3)) do x
           sum(sum(KernelFunctions.ColVecs(x)))
       end
ERROR: In slow method
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::KernelFunctions.var"#back#173")(::FillArrays.Fill{FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at /home/david/.julia/dev/KernelFunctions/src/zygote_adjoints.jl:66
 [3] (::KernelFunctions.var"#159#back#174"{KernelFunctions.var"#back#173"})(::FillArrays.Fill{FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at /home/david/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] #29 at ./REPL[30]:2 [inlined]
 [5] (::typeof(∂(#29)))(::Float64) at /home/david/.julia/packages/Zygote/rqvFi/src/compiler/interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(#29))})(::Float64) at /home/david/.julia/packages/Zygote/rqvFi/src/compiler/interface.jl:45
 [7] gradient(::Function, ::Array{Float64,2}) at /home/david/.julia/packages/Zygote/rqvFi/src/compiler/interface.jl:54
 [8] top-level scope at REPL[30]:1

@devmotion
Member

I guess here the problem might be that it falls back to

pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, permutedims(Y))
for mixed inputs.
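
To illustrate what that generic fallback computes, here is a minimal sketch in plain Julia with no package dependencies. The `sqeuclidean` function is a hypothetical stand-in for a `PreMetric`, and the two vectors of vectors only imitate what `ColVecs` and `RowVecs` present to the caller; neither is the actual KernelFunctions or Distances code.

```julia
# Hypothetical stand-in for a PreMetric: squared Euclidean distance.
sqeuclidean(a, b) = sum(abs2, a .- b)

A = [1.0 2.0 3.0;
     4.0 5.0 6.0]   # 2×3: observations stored as columns
B = [1.0 0.0;
     0.0 1.0;
     1.0 1.0]       # 3×2: observations stored as rows

# Plain vectors of vectors, imitating ColVecs(A) and RowVecs(B).
X = [A[:, j] for j in axes(A, 2)]
Y = [B[i, :] for i in axes(B, 1)]

# The generic fallback: broadcast the distance over X and a 1×n row of Y.
# permutedims is non-recursive, so the inner vectors are left untouched.
K = broadcast(sqeuclidean, X, permutedims(Y))
```

Each entry `K[i, j]` is `sqeuclidean(X[i], Y[j])`, computed element by element rather than through a specialised matrix method, which is presumably why a custom adjoint written for the matrix-backed case is not hit the same way with mixed inputs.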

@sharanry sharanry merged commit cfdb033 into master Aug 31, 2020
@devmotion devmotion deleted the sharan/fix-AD-bug branch August 31, 2020 16:28
@devmotion
Member

Oh, I forgot: we should bump the patch version so that we can make a new release, @sharanry.

Successfully merging this pull request may close these issues.

AD issue with mixed inputs