GPU memory filling up #150
Any update on this? It seems like a major issue to me, as it significantly slows down any form of inference.
I tried to reduce the example. It turns out that the problem is not GNN.jl related but due to the allocation strategy of CUDA.jl (about which I know very little). In your original example the comparison wasn't measuring comparable operations for the GNN and the NN. Here are a few comparisons:

```julia
using GraphNeuralNetworks, CUDA, Flux

N = 10000
I = 10

function test_mem(n, data)
    for i in 1:I
        y = n(data)
        CUDA.memory_status()
    end
end

GC.gc(); CUDA.reclaim();
println("GNN, memory filling")
g = GNNGraph(collect(1:N-1), collect(2:N), num_nodes = N, ndata = rand(Float32, 1, N)) |> gpu
gnnchain = GNNChain(Dense(1, 1000), Dense(1000, 1)) |> gpu
CUDA.@time test_mem(gnnchain, g)

GC.gc(); CUDA.reclaim();
println("\n\nNN equivalent to GNN, memory filling")
x = g.ndata.x
chain = Chain(gnnchain.layers...)
@assert gnnchain(g).ndata.x ≈ chain(x)
## same results with these
# x = rand(Float32, 1, N) |> gpu
# chain = Chain(Dense(1, 1000), Dense(1000, 1)) |> gpu
CUDA.@time test_mem(chain, x)

GC.gc(); CUDA.reclaim();
println("\n\nNN 1, same memory")
data = rand(Float32, N, 1) |> gpu
n = Chain(Dense(N, 1000), Dense(1000, 1)) |> gpu
CUDA.@time test_mem(n, data)

println("\n\nNN 2, memory filling:")
data = rand(Float32, N, N) |> gpu
n = Chain(Dense(N, 1000), Dense(1000, 1)) |> gpu
CUDA.@time test_mem(n, data)
```
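For context, here is a minimal sketch of the pooling behavior I believe is at play (an assumption about CUDA.jl's default caching allocator, not something confirmed in this thread): freed arrays stay cached in the pool, so reported usage keeps climbing until `CUDA.reclaim()` hands the memory back.

```julia
using CUDA

# Assumption: CUDA.jl's default pooling allocator caches freed buffers.
x = CUDA.rand(Float32, 10_000, 10_000)  # ~400 MB device allocation
CUDA.memory_status()                    # pool now holds the buffer
x = nothing
GC.gc()                                 # Julia frees the CuArray...
CUDA.memory_status()                    # ...but the pool keeps it cached
CUDA.reclaim()                          # return cached memory to the driver
CUDA.memory_status()                    # reported usage drops back down
```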
I see, my bad. I wrongly assumed the equivalent of the GNN to be
When repeatedly performing inference operations, the GPU memory fills up quite fast, forcing the GPU to perform garbage collection. In my implementation this accounted for 50% of GPU time, as GC was needed every couple of seconds. With a plain NN, GPU memory grows much more slowly.
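One workaround for GC pauses like this (a sketch only; `inference_loop` and `reclaim_every` are hypothetical names, not part of any library) is to reclaim cached GPU memory at a fixed cadence rather than waiting for pressure-driven collection:

```julia
using CUDA

# Hypothetical sketch: periodically return pooled GPU memory during
# repeated inference, instead of relying on pressure-driven GC.
function inference_loop(model, batches; reclaim_every = 100)
    for (i, x) in enumerate(batches)
        y = model(x)
        # ... consume y ...
        if i % reclaim_every == 0
            GC.gc(false)    # quick incremental GC pass
            CUDA.reclaim()  # hand cached pool memory back to the driver
        end
    end
end
```

Whether this helps depends on the workload; reclaiming too often adds driver-call overhead of its own.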