Skip to content

Commit 049dd9e

Browse files
committed
Accelerate connection sets merging with union-find
Baseline: ```julia julia> @time run_and_time_julia!(ss_times, times, max_sizes, 1, 100) 0.960393 seconds (8.15 M allocations: 540.022 MiB, 4.47% gc time) 0.888558584 julia> @time run_and_time_julia!(ss_times, times, max_sizes, 1, 200) 2.593054 seconds (17.95 M allocations: 1.131 GiB, 3.75% gc time) 2.465012458 julia> @time run_and_time_julia!(ss_times, times, max_sizes, 1, 300) 5.065673 seconds (29.41 M allocations: 1.821 GiB, 5.90% gc time) 4.861177375 ``` PR: ```julia julia> @time run_and_time_julia!(ss_times, times, max_sizes, 1, 100); 0.748587 seconds (7.61 M allocations: 513.135 MiB, 7.15% gc time) julia> @time run_and_time_julia!(ss_times, times, max_sizes, 1, 200); 1.681521 seconds (15.75 M allocations: 1.027 GiB, 7.71% gc time) julia> @time run_and_time_julia!(ss_times, times, max_sizes, 1, 300); 2.931254 seconds (24.43 M allocations: 1.590 GiB, 11.97% gc time) ```
1 parent 988caca commit 049dd9e

File tree

1 file changed

+28
-40
lines changed

1 file changed

+28
-40
lines changed

src/systems/connectors.jl

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ end
187187
struct ConnectionSet
188188
set::Vector{ConnectionElement} # namespace.sys, var, isouter
189189
end
190+
ConnectionSet() = ConnectionSet(ConnectionElement[])
190191
Base.copy(c::ConnectionSet) = ConnectionSet(copy(c.set))
191192

192193
function Base.show(io::IO, c::ConnectionSet)
@@ -373,51 +374,38 @@ function generate_connection_set!(connectionsets, domain_csets,
373374
end
374375

375376
function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
376-
csets, merged = partial_merge(csets, allouter)
377-
while merged
378-
csets, merged = partial_merge(csets)
379-
end
380-
csets
381-
end
382-
383-
function partial_merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
384-
mcsets = ConnectionSet[]
385377
ele2idx = Dict{ConnectionElement, Int}()
386-
cacheset = Set{ConnectionElement}()
387-
merged = false
388-
for (j, cset) in enumerate(csets)
389-
if allouter
390-
cset = ConnectionSet(map(withtrueouter, cset.set))
391-
end
392-
idx = nothing
393-
for e in cset.set
394-
idx = get(ele2idx, e, nothing)
395-
if idx !== nothing
396-
merged = true
397-
break
378+
idx2ele = ConnectionElement[]
379+
union_find = IntDisjointSets(0)
380+
prev_id = Ref(-1)
381+
for cset in csets, (j, s) in enumerate(cset.set)
382+
v = allouter ? withtrueouter(s) : s
383+
id = let ele2idx = ele2idx, idx2ele = idx2ele
384+
get!(ele2idx, v) do
385+
push!(idx2ele, v)
386+
id = length(idx2ele)
387+
id′ = push!(union_find)
388+
@assert id == id′
389+
id
398390
end
399391
end
400-
if idx === nothing
401-
push!(mcsets, copy(cset))
402-
for e in cset.set
403-
ele2idx[e] = length(mcsets)
404-
end
405-
else
406-
for e in mcsets[idx].set
407-
push!(cacheset, e)
408-
end
409-
for e in cset.set
410-
push!(cacheset, e)
411-
end
412-
empty!(mcsets[idx].set)
413-
for e in cacheset
414-
ele2idx[e] = idx
415-
push!(mcsets[idx].set, e)
416-
end
417-
empty!(cacheset)
392+
if j > 1
393+
union!(union_find, prev_id[], id)
394+
end
395+
prev_id[] = id
396+
end
397+
id2set = Dict{Int, ConnectionSet}()
398+
merged_set = ConnectionSet[]
399+
for (id, ele) in enumerate(idx2ele)
400+
rid = find_root(union_find, id)
401+
set = get!(id2set, rid) do
402+
set = ConnectionSet()
403+
push!(merged_set, set)
404+
set
418405
end
406+
push!(set.set, ele)
419407
end
420-
mcsets, merged
408+
merged_set
421409
end
422410

423411
function generate_connection_equations_and_stream_connections(csets::AbstractVector{

0 commit comments

Comments
 (0)