Skip to content

Accelerate connection sets merging with union-find #2824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 40 additions & 41 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,16 @@ end
Base.nameof(l::ConnectionElement) = renamespace(nameof(l.sys), getname(l.v))
Base.isequal(l1::ConnectionElement, l2::ConnectionElement) = l1 == l2
function Base.:(==)(l1::ConnectionElement, l2::ConnectionElement)
nameof(l1.sys) == nameof(l2.sys) && isequal(l1.v, l2.v) && l1.isouter == l2.isouter
l1.isouter == l2.isouter && nameof(l1.sys) == nameof(l2.sys) && isequal(l1.v, l2.v)
end

const _debug_mode = Base.JLOptions().check_bounds == 1

function Base.show(io::IO, c::ConnectionElement)
@unpack sys, v, isouter = c
print(io, nameof(sys), ".", v, "::", isouter ? "outer" : "inner")
end

function Base.hash(e::ConnectionElement, salt::UInt)
if _debug_mode
@assert e.h === _hash_impl(e.sys, e.v, e.isouter)
Expand All @@ -187,7 +192,10 @@ end
struct ConnectionSet
set::Vector{ConnectionElement} # namespace.sys, var, isouter
end
ConnectionSet() = ConnectionSet(ConnectionElement[])
Base.copy(c::ConnectionSet) = ConnectionSet(copy(c.set))
Base.:(==)(a::ConnectionSet, b::ConnectionSet) = a.set == b.set
Base.sort(a::ConnectionSet) = ConnectionSet(sort(a.set, by = string))

function Base.show(io::IO, c::ConnectionSet)
print(io, "<")
Expand Down Expand Up @@ -373,51 +381,42 @@ function generate_connection_set!(connectionsets, domain_csets,
end

function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
csets, merged = partial_merge(csets, allouter)
while merged
csets, merged = partial_merge(csets)
end
csets
end

function partial_merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
mcsets = ConnectionSet[]
ele2idx = Dict{ConnectionElement, Int}()
cacheset = Set{ConnectionElement}()
merged = false
for (j, cset) in enumerate(csets)
if allouter
cset = ConnectionSet(map(withtrueouter, cset.set))
end
idx = nothing
for e in cset.set
idx = get(ele2idx, e, nothing)
if idx !== nothing
merged = true
break
idx2ele = ConnectionElement[]
union_find = IntDisjointSets(0)
prev_id = Ref(-1)
for cset in csets, (j, s) in enumerate(cset.set)
v = allouter ? withtrueouter(s) : s
id = let ele2idx = ele2idx, idx2ele = idx2ele
get!(ele2idx, v) do
push!(idx2ele, v)
id = length(idx2ele)
id′ = push!(union_find)
@assert id == id′
id
end
end
if idx === nothing
push!(mcsets, copy(cset))
for e in cset.set
ele2idx[e] = length(mcsets)
end
else
for e in mcsets[idx].set
push!(cacheset, e)
end
for e in cset.set
push!(cacheset, e)
end
empty!(mcsets[idx].set)
for e in cacheset
ele2idx[e] = idx
push!(mcsets[idx].set, e)
end
empty!(cacheset)
# isequal might not be equal? lol
if v.sys.namespace !== nothing
idx2ele[id] = v
end
if j > 1
union!(union_find, prev_id[], id)
end
prev_id[] = id
end
id2set = Dict{Int, Int}()
merged_set = ConnectionSet[]
for (id, ele) in enumerate(idx2ele)
rid = find_root(union_find, id)
set_idx = get!(id2set, rid) do
set = ConnectionSet()
push!(merged_set, set)
length(merged_set)
end
push!(merged_set[set_idx].set, ele)
end
mcsets, merged
merged_set
end

function generate_connection_equations_and_stream_connections(csets::AbstractVector{
Expand Down
Loading