I am building a model whose forward function needs to batch the input GNNGraphs. During backpropagation, the following error is raised:
ERROR: Mutating arrays is not supported -- called copyto!(Vector{Symbol}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
I also wrote a dummy example to reproduce this error:
using Zygote
using GNNGraphs
using Flux
using CUDA
g1 = GNNGraph([1,2,3], [2,3,4])
g2 = GNNGraph([1,2,3], [2,4,5])
function test_fn(x)
    graphs = [g1, g2]
    gs = batch(graphs)
    return sum(gs.num_nodes)
end
# This will error
gradient(test_fn, 1.0)
Is there a way to work around this? Or is it valid to simply use Zygote.@nograd batch? Any insights are welcome :)
Yes, batch is not differentiable at the moment. One typically does the batching before taking the gradient, though.
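As a minimal sketch of batching before taking the gradient, the graphs from the report can be batched once up front, so only the scalar computation is differentiated. The `loss` function here is a hypothetical stand-in for a real model, not something from GNNGraphs:

```julia
using GNNGraphs, Zygote

graphs = [GNNGraph([1,2,3], [2,3,4]), GNNGraph([1,2,3], [2,4,5])]

# Batch outside of the differentiated code, so Zygote never traces `batch`.
gs = batch(graphs)

# Hypothetical loss: scales the total node count of the batched graph by x.
loss(x, g) = sum(g.num_nodes * x)

# Differentiate with respect to x only; gs is captured as a constant.
grad = gradient(x -> loss(x, gs), 2.0)
```

Since the batched graph has 4 + 5 = 9 nodes, the derivative with respect to x should be 9.0.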
If used within a gradient context, it should be used like this:
julia> using GNNGraphs, Zygote
julia> function test_fn(x, graphs)
           gs = Zygote.ignore_derivatives() do
               batch(graphs)
           end
           return sum(gs.num_nodes * x)
       end
test_fn (generic function with 2 methods)
julia> graphs = [GNNGraph([1,2,3], [2,3,4]), GNNGraph([1,2,3], [2,4,5])]
2-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
 GNNGraph(4, 3) with no data
 GNNGraph(5, 3) with no data
julia> gradient(test_fn, 2.0, graphs)
(9.0, nothing)
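The same pattern can be illustrated without GNNGraphs. The sketch below assumes a hypothetical helper, `build_offsets`, that fills an array in place, which is the kind of operation that typically triggers the "Mutating arrays is not supported" error when Zygote has to differentiate through it. Wrapping the call in `Zygote.ignore_derivatives`, as in the reply above, makes Zygote treat the result as a constant:

```julia
using Zygote

# Hypothetical helper that builds a vector by writing into it in place.
function build_offsets(n)
    v = zeros(n)
    for i in 1:n
        v[i] = i   # in-place mutation
    end
    return v
end

function scaled_total(x)
    # The mutating code runs normally but is excluded from differentiation.
    offsets = Zygote.ignore_derivatives() do
        build_offsets(3)
    end
    return x * sum(offsets)
end

g = gradient(scaled_total, 2.0)
```

Here `sum(offsets)` is 6.0, so the gradient with respect to x is 6.0. This only makes sense when no gradient is needed through the ignored block, which is the case for graph structure such as the output of batch.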