Skip to content

Getcoeffchunk #2393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand Down Expand Up @@ -74,6 +75,7 @@ Distributions = "0.23, 0.24, 0.25"
DocStringExtensions = "0.7, 0.8, 0.9"
DomainSets = "0.6"
DynamicQuantities = "^0.11.2"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
Expand Down
11 changes: 2 additions & 9 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ the `constraint`.
mask,
constraint)
eadj = M.row_cols
for i in range
@inbounds for i in range
vertices = eadj[i]
if constraint(length(vertices))
for (j, v) in enumerate(vertices)
Expand All @@ -170,7 +170,7 @@ end
range,
mask,
constraint)
for i in range
@inbounds for i in range
row = @view M[i, :]
if constraint(count(!iszero, row))
for (v, val) in enumerate(row)
Expand Down Expand Up @@ -382,13 +382,6 @@ end

swap!(v, i, j) = v[i], v[j] = v[j], v[i]

function getcoeff(vars, coeffs, var)
for (vj, v) in enumerate(vars)
v == var && return coeffs[vj]
end
return 0
end

"""
$(SIGNATURES)

Expand Down
133 changes: 114 additions & 19 deletions src/systems/sparsematrixclil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ end
# build something that works for us here and worry about it later.
nonzerosmap(a::CLILVector) = NonZeros(a)

using FindFirstFunctions: findfirstequal

function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot,
last_pivot; pivot_equal_optimization = true)
# for ei in nzrows(>= k)
Expand Down Expand Up @@ -168,12 +170,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
# conservative, we leave it at this, as this captures the most important
# case for MTK (where most pivots are `1` or `-1`).
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)

for ei in (k + 1):size(M, 1)
@inbounds for ei in (k + 1):size(M, 1)
# eliminate `v`
coeff = 0
ivars = eadj[ei]
vj = findfirst(isequal(vpivot), ivars)
vj = findfirstequal(vpivot, ivars)
if vj !== nothing
coeff = old_cadj[ei][vj]
deleteat!(old_cadj[ei], vj)
Expand All @@ -189,24 +190,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
ivars = eadj[ei]
icoeffs = old_cadj[ei]

tmp_incidence = similar(eadj[ei], 0)
tmp_coeffs = similar(old_cadj[ei], 0)
# TODO: We know both ivars and kvars are sorted, we could just write
# a quick iterator here that does this without allocation/faster.
vars = sort(union(ivars, kvars))

for v in vars
v == vpivot && continue
ck = getcoeff(kvars, kcoeffs, v)
ci = getcoeff(ivars, icoeffs, v)
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
if !iszero(ci)
push!(tmp_incidence, v)
push!(tmp_coeffs, ci)
numkvars = length(kvars)
numivars = length(ivars)
tmp_incidence = similar(eadj[ei], numkvars + numivars)
tmp_coeffs = similar(old_cadj[ei], numkvars + numivars)
tmp_len = 0
kvind = ivind = 0
if _debug_mode
# in debug mode, we at least check to confirm we're iterating over
# `v`s in the correct order
vars = sort(union(ivars, kvars))
vi = 0
end
if numivars > 0 && numkvars > 0
kvv = kvars[kvind += 1]
ivv = ivars[ivind += 1]
dobreak = false
while true
if kvv == ivv
v = kvv
ck = kcoeffs[kvind]
ci = icoeffs[ivind]
kvind += 1
ivind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
elseif kvv < ivv
v = kvv
ck = kcoeffs[kvind]
kvind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
else # kvv > ivv
v = ivv
ci = icoeffs[ivind]
ivind += 1
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot)
end
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot && !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
dobreak && break
end
elseif numkvars > 0
ivind = 1
kvv = kvars[kvind += 1]
elseif numivars > 0
kvind = 1
ivv = ivars[ivind += 1]
end
if kvind <= numkvars
v = kvv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
ck = kcoeffs[kvind]
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
if !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
end
(kvind == numkvars) && break
v = kvars[kvind += 1]
end
elseif ivind <= numivars
v = ivv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
ci = exactdiv(p1, last_pivot)
if !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
end
(ivind == numivars) && break
v = ivars[ivind += 1]
end
end
resize!(tmp_incidence, tmp_len)
resize!(tmp_coeffs, tmp_len)
eadj[ei] = tmp_incidence
old_cadj[ei] = tmp_coeffs
end
Expand Down