Skip to content

Commit c7d0afe

Browse files
MLUtils and Flux v0.13 compatibility (#155)
* MLUtils and Flux v0.13 compatibility * bimp version * cleanup * cleanup * import MLUtils * using numobs, getobs
1 parent 2801b51 commit c7d0afe

File tree

8 files changed

+22
-153
lines changed

8 files changed

+22
-153
lines changed

Project.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphNeuralNetworks"
22
uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.3.15"
4+
version = "0.4.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,8 +12,8 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1212
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1313
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1414
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
15-
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
1615
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
16+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1717
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1818
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1919
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -29,17 +29,17 @@ Adapt = "3"
2929
CUDA = "3.3"
3030
ChainRulesCore = "1"
3131
DataStructures = "0.18"
32-
Flux = "0.12.7"
32+
Flux = "0.13"
3333
Functors = "0.2"
3434
Graphs = "1.4"
3535
KrylovKit = "0.5"
36-
LearnBase = "0.4, 0.5, 0.6"
36+
MLUtils = "0.2.3"
3737
MacroTools = "0.5"
38-
NNlib = "0.7, 0.8"
39-
NNlibCUDA = "0.1, 0.2"
38+
NNlib = "0.8"
39+
NNlibCUDA = "0.2"
4040
NearestNeighbors = "0.4"
4141
Reexport = "1"
42-
StatsBase = "0.32, 0.33"
42+
StatsBase = "0.33"
4343
julia = "1.6"
4444

4545
[extras]

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ import Flux
99
using Flux: batch
1010
import NearestNeighbors
1111
import NNlib
12-
import LearnBase
1312
import StatsBase
1413
import KrylovKit
1514
using ChainRulesCore
1615
using LinearAlgebra, Random, Statistics
16+
import MLUtils
17+
using MLUtils: getobs, numobs
1718

1819
include("gnngraph.jl")
1920
export GNNGraph,

src/GNNGraphs/gnngraph.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,14 @@ function Base.show(io::IO, g::GNNGraph)
223223
end
224224
end
225225

226-
### StatsBase/LearnBase compatibility
227-
StatsBase.nobs(g::GNNGraph) = g.num_graphs
228-
LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
229-
230-
# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
231-
Flux.Data._nobs(g::GNNGraph) = g.num_graphs
232-
Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)
226+
MLUtils.numobs(g::GNNGraph) = g.num_graphs
227+
MLUtils.getobs(g::GNNGraph, i) = getgraph(g, i)
233228

234229
# DataLoader compatibility passing a vector of graphs and
235230
# effectively using `batch` as a collated function.
236-
StatsBase.nobs(data::Vector{<:GNNGraph}) = length(data)
237-
LearnBase.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
238-
LearnBase.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])
239-
Flux.Data._nobs(g::Vector{<:GNNGraph}) = StatsBase.nobs(g)
240-
Flux.Data._getobs(g::Vector{<:GNNGraph}, i) = LearnBase.getobs(g, i)
231+
MLUtils.numobs(data::Vector{<:GNNGraph}) = length(data)
232+
MLUtils.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
233+
MLUtils.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])
241234

242235

243236
#########################

src/GNNGraphs/utils.jl

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -156,64 +156,3 @@ binarize(x) = map(>(0), x)
156156
@non_differentiable binarize(x...)
157157
@non_differentiable edge_encoding(x...)
158158
@non_differentiable edge_decoding(x...)
159-
160-
161-
162-
####################################
163-
# FROM MLBASE.jl
164-
# https://github.com/JuliaML/MLBase.jl/pull/1/files
165-
# remove when package is registered
166-
##############################################
167-
168-
numobs(A::AbstractArray{<:Any, N}) where {N} = size(A, N)
169-
170-
# 0-dim arrays
171-
numobs(A::AbstractArray{<:Any, 0}) = 1
172-
173-
function getobs(A::AbstractArray{<:Any, N}, idx) where N
174-
I = ntuple(_ -> :, N-1)
175-
return A[I..., idx]
176-
end
177-
178-
getobs(A::AbstractArray{<:Any, 0}, idx) = A[idx]
179-
180-
function getobs!(buffer::AbstractArray, A::AbstractArray{<:Any, N}, idx) where N
181-
I = ntuple(_ -> :, N-1)
182-
buffer .= A[I..., idx]
183-
return buffer
184-
end
185-
186-
# --------------------------------------------------------------------
187-
# Tuples and NamedTuples
188-
189-
_check_numobs_error() =
190-
throw(DimensionMismatch("All data containers must have the same number of observations."))
191-
192-
function _check_numobs(tup::Union{Tuple, NamedTuple})
193-
length(tup) == 0 && return
194-
n1 = numobs(tup[1])
195-
for i=2:length(tup)
196-
numobs(tup[i]) != n1 && _check_numobs_error()
197-
end
198-
end
199-
200-
function numobs(tup::Union{Tuple, NamedTuple})::Int
201-
_check_numobs(tup)
202-
return length(tup) == 0 ? 0 : numobs(tup[1])
203-
end
204-
205-
function getobs(tup::Union{Tuple, NamedTuple}, indices)
206-
_check_numobs(tup)
207-
return map(x -> getobs(x, indices), tup)
208-
end
209-
210-
function getobs!(buffers::Union{Tuple, NamedTuple},
211-
tup::Union{Tuple, NamedTuple},
212-
indices)
213-
_check_numobs(tup)
214-
215-
return map(buffers, tup) do buffer, x
216-
getobs!(buffer, x, indices)
217-
end
218-
end
219-
#######################################################

src/deprecations.jl

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +0,0 @@
1-
## Deprecated in v0.2
2-
3-
function compute_message end
4-
function update_node end
5-
function update_edge end
6-
7-
compute_message(l, xi, xj, e) = compute_message(l, xi, xj)
8-
update_node(l, x, m̄) = m̄
9-
update_edge(l, e, m) = e
10-
11-
function propagate(l::GNNLayer, g::GNNGraph, aggr, x, e=nothing)
12-
@warn """
13-
Passing a GNNLayer to propagate is deprecated,
14-
you should pass the message function directly.
15-
The new signature is `propagate(f, g, aggr; [xi, xj, e])`.
16-
17-
The functions `compute_message`, `update_node`,
18-
and `update_edge` have been deprecated as well. Please
19-
refer to the documentation.
20-
"""
21-
m = apply_edges((a...) -> compute_message(l, a...), g, x, x, e)
22-
m̄ = aggregate_neighbors(g, aggr, m)
23-
x = update_node(l, x, m̄)
24-
e = update_edge(l, e, m)
25-
return x, e
26-
end
27-
28-
## Deprecated in v0.3
29-
30-
@deprecate copyxj(xi, xj, e) copy_xj(xi, xj, e)
31-
32-
@deprecate CGConv(nin::Int, ein::Int, out::Int, args...; kws...) CGConv((nin, ein) => out, args...; kws...)
33-

test/GNNGraphs/gnngraph.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@
245245
@test_throws AssertionError rand_graph(10, 30, ndata=1, graph_type=GRAPH_T)
246246
end
247247

248-
@testset "LearnBase and DataLoader compat" begin
248+
@testset "MLUtils and DataLoader compat" begin
249249
n, m, num_graphs = 10, 30, 50
250250
X = rand(10, n)
251251
E = rand(10, m)
@@ -255,18 +255,18 @@
255255
g = Flux.batch(data)
256256

257257
@testset "batch then pass to dataloader" begin
258-
@test LearnBase.getobs(g, 3) == getgraph(g, 3)
259-
@test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
260-
@test StatsBase.nobs(g) == g.num_graphs
258+
@test MLUtils.getobs(g, 3) == getgraph(g, 3)
259+
@test MLUtils.getobs(g, 3:5) == getgraph(g, 3:5)
260+
@test MLUtils.numobs(g) == g.num_graphs
261261

262262
d = Flux.Data.DataLoader(g, batchsize=2, shuffle=false)
263263
@test first(d) == getgraph(g, 1:2)
264264
end
265265

266266
@testset "pass to dataloader and collate" begin
267-
@test LearnBase.getobs(data, 3) == getgraph(g, 3)
268-
@test LearnBase.getobs(data, 3:5) == getgraph(g, 3:5)
269-
@test StatsBase.nobs(data) == g.num_graphs
267+
@test MLUtils.getobs(data, 3) == getgraph(g, 3)
268+
@test MLUtils.getobs(data, 3:5) == getgraph(g, 3:5)
269+
@test MLUtils.numobs(data) == g.num_graphs
270270

271271
d = Flux.Data.DataLoader(data, batchsize=2, shuffle=false)
272272
@test first(d) == getgraph(g, 1:2)

test/deprecations.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,3 @@
11
@testset "deprecations" begin
2-
@testset "propagate" begin
3-
struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
4-
weight::A
5-
bias::B
6-
σ::F
7-
end
82

9-
Flux.@functor GCN # allow collecting params, gpu movement, etc...
10-
11-
function GCN(ch::Pair{Int,Int}, σ=identity)
12-
in, out = ch
13-
W = Flux.glorot_uniform(out, in)
14-
b = zeros(Float32, out)
15-
GCN(W, b, σ)
16-
end
17-
18-
GraphNeuralNetworks.compute_message(l::GCN, xi, xj, e) = xj
19-
20-
function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
21-
x, _ = propagate(l, g, +, x)
22-
return l.σ.(l.weight * x .+ l.bias)
23-
end
24-
25-
function new_forward(l, g, x)
26-
x = propagate(copy_xj, g, +, xj=x)
27-
return l.σ.(l.weight * x .+ l.bias)
28-
end
29-
30-
g = GNNGraph(random_regular_graph(10, 4), ndata=randn(3, 10))
31-
l = GCN(3 => 5, tanh)
32-
@test l(g, g.ndata.x) ≈ new_forward(l, g, g.ndata.x)
33-
end
343
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using CUDA
55
using Flux: gpu, @functor
66
using LinearAlgebra, Statistics, Random
77
using NNlib
8-
using LearnBase
8+
import MLUtils
99
import StatsBase
1010
using SparseArrays
1111
using Graphs

0 commit comments

Comments
 (0)