Skip to content

Commit 5f2a594

Browse files
authored
Merge pull request #2824 from SciML/myb/union_find
Accelerate connection sets merging with union-find
2 parents ddcf59e + 53eab1e commit 5f2a594

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

src/systems/connectors.jl

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,16 @@ end
163163
Base.nameof(l::ConnectionElement) = renamespace(nameof(l.sys), getname(l.v))
164164
Base.isequal(l1::ConnectionElement, l2::ConnectionElement) = l1 == l2
165165
function Base.:(==)(l1::ConnectionElement, l2::ConnectionElement)
166-
nameof(l1.sys) == nameof(l2.sys) && isequal(l1.v, l2.v) && l1.isouter == l2.isouter
166+
l1.isouter == l2.isouter && nameof(l1.sys) == nameof(l2.sys) && isequal(l1.v, l2.v)
167167
end
168168

169169
const _debug_mode = Base.JLOptions().check_bounds == 1
170170

171+
function Base.show(io::IO, c::ConnectionElement)
172+
@unpack sys, v, isouter = c
173+
print(io, nameof(sys), ".", v, "::", isouter ? "outer" : "inner")
174+
end
175+
171176
function Base.hash(e::ConnectionElement, salt::UInt)
172177
if _debug_mode
173178
@assert e.h === _hash_impl(e.sys, e.v, e.isouter)
@@ -187,7 +192,10 @@ end
187192
struct ConnectionSet
188193
set::Vector{ConnectionElement} # namespace.sys, var, isouter
189194
end
195+
ConnectionSet() = ConnectionSet(ConnectionElement[])
190196
Base.copy(c::ConnectionSet) = ConnectionSet(copy(c.set))
197+
Base.:(==)(a::ConnectionSet, b::ConnectionSet) = a.set == b.set
198+
Base.sort(a::ConnectionSet) = ConnectionSet(sort(a.set, by = string))
191199

192200
function Base.show(io::IO, c::ConnectionSet)
193201
print(io, "<")
@@ -373,51 +381,42 @@ function generate_connection_set!(connectionsets, domain_csets,
373381
end
374382

375383
function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
376-
csets, merged = partial_merge(csets, allouter)
377-
while merged
378-
csets, merged = partial_merge(csets)
379-
end
380-
csets
381-
end
382-
383-
function partial_merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
384-
mcsets = ConnectionSet[]
385384
ele2idx = Dict{ConnectionElement, Int}()
386-
cacheset = Set{ConnectionElement}()
387-
merged = false
388-
for (j, cset) in enumerate(csets)
389-
if allouter
390-
cset = ConnectionSet(map(withtrueouter, cset.set))
391-
end
392-
idx = nothing
393-
for e in cset.set
394-
idx = get(ele2idx, e, nothing)
395-
if idx !== nothing
396-
merged = true
397-
break
385+
idx2ele = ConnectionElement[]
386+
union_find = IntDisjointSets(0)
387+
prev_id = Ref(-1)
388+
for cset in csets, (j, s) in enumerate(cset.set)
389+
v = allouter ? withtrueouter(s) : s
390+
id = let ele2idx = ele2idx, idx2ele = idx2ele
391+
get!(ele2idx, v) do
392+
push!(idx2ele, v)
393+
id = length(idx2ele)
394+
id′ = push!(union_find)
395+
@assert id == id′
396+
id
398397
end
399398
end
400-
if idx === nothing
401-
push!(mcsets, copy(cset))
402-
for e in cset.set
403-
ele2idx[e] = length(mcsets)
404-
end
405-
else
406-
for e in mcsets[idx].set
407-
push!(cacheset, e)
408-
end
409-
for e in cset.set
410-
push!(cacheset, e)
411-
end
412-
empty!(mcsets[idx].set)
413-
for e in cacheset
414-
ele2idx[e] = idx
415-
push!(mcsets[idx].set, e)
416-
end
417-
empty!(cacheset)
399+
# isequal might not be equal? lol
400+
if v.sys.namespace !== nothing
401+
idx2ele[id] = v
402+
end
403+
if j > 1
404+
union!(union_find, prev_id[], id)
405+
end
406+
prev_id[] = id
407+
end
408+
id2set = Dict{Int, Int}()
409+
merged_set = ConnectionSet[]
410+
for (id, ele) in enumerate(idx2ele)
411+
rid = find_root(union_find, id)
412+
set_idx = get!(id2set, rid) do
413+
set = ConnectionSet()
414+
push!(merged_set, set)
415+
length(merged_set)
418416
end
417+
push!(merged_set[set_idx].set, ele)
419418
end
420-
mcsets, merged
419+
merged_set
421420
end
422421

423422
function generate_connection_equations_and_stream_connections(csets::AbstractVector{

0 commit comments

Comments
 (0)