Skip to content

Commit 04962e5

Browse files
committed
optimize bareiss_update_virtual_colswap_mtk!
1 parent e089c9e commit 04962e5

File tree

2 files changed

+107
-43
lines changed

2 files changed

+107
-43
lines changed

src/systems/alias_elimination.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -382,36 +382,6 @@ end
382382

383383
swap!(v, i, j) = v[i], v[j] = v[j], v[i]
384384

385-
# Return the entry of `coeffs` located at the position where `var` first
# occurs in `vars`, or the literal `0` when `var` does not occur at all.
# The search handles `vars` in chunks of 8 entries and then falls back to an
# element-by-element scan for the remaining tail.
# NOTE(review): the `@inbounds` accesses assume `coeffs` is at least as long
# as `vars` — confirm at the call sites.
function getcoeff(vars, coeffs, var)
386-
Nvars = length(vars)
387-
i = 0
388-
chunk_size = 8
389-
# Loop while a full chunk of `chunk_size` entries remains after offset `i`.
@inbounds while i < Nvars - chunk_size + 1
390-
# btup[j] is true when vars[i + j] == var, for j = 1:chunk_size.
btup = let vars = vars, var = var, i = i
391-
ntuple(Val(chunk_size)) do j
392-
@inbounds vars[i + j] == var
393-
end
394-
end
395-
# (0, 1, ..., 7): zero-based offsets inside the current chunk.
inds = ntuple(Base.Fix2(-, 1), Val(8))
396-
# Sentinel value 8 marks "no match at this position".
eights = ntuple(Returns(8), Val(8))
397-
# Keep the offset where the entry matched, the sentinel otherwise.
inds = map(ifelse, btup, inds, eights)
398-
# Pairwise `min` reduction (8 -> 4 -> 2 -> 1) yields the smallest matching
# offset, or 8 when nothing in the chunk matched.
inds4 = (min(inds[1], inds[5]),
399-
min(inds[2], inds[6]),
400-
min(inds[3], inds[7]),
401-
min(inds[4], inds[8]))
402-
inds2 = (min(inds4[1], inds4[3]), min(inds4[2], inds4[4]))
403-
ind = min(inds2[1], inds2[2])
404-
if ind != 8
405-
# `ind` is zero-based within the chunk, hence the `+ 1`.
return coeffs[i + ind + 1]
406-
end
407-
i += chunk_size
408-
end
409-
# Fewer than `chunk_size` entries remain: scan them one at a time.
@inbounds for vj in (i + 1):Nvars
410-
vars[vj] == var && return coeffs[vj]
411-
end
412-
# `var` was not found; note this is the `Int` literal 0 regardless of the
# element type of `coeffs`.
return 0
413-
end
414-
415385
"""
416386
$(SIGNATURES)
417387

src/systems/sparsematrixclil.jl

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
169169
# case for MTK (where most pivots are `1` or `-1`).
170170
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)
171171

172-
for ei in (k + 1):size(M, 1)
172+
@inbounds for ei in (k + 1):size(M, 1)
173173
# eliminate `v`
174174
coeff = 0
175175
ivars = eadj[ei]
@@ -193,18 +193,112 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
193193
tmp_coeffs = similar(old_cadj[ei], 0)
194194
# Both `ivars` and `kvars` are sorted, so the loops below merge them in a
195195
# single pass; `sort(union(ivars, kvars))` is only built in debug mode.
196-
vars = sort(union(ivars, kvars))
197-
198-
for v in vars
199-
v == vpivot && continue
200-
ck = getcoeff(kvars, kcoeffs, v)
201-
ci = getcoeff(ivars, icoeffs, v)
202-
p1 = Base.Checked.checked_mul(pivot, ci)
203-
p2 = Base.Checked.checked_mul(coeff, ck)
204-
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
205-
if !iszero(ci)
206-
push!(tmp_incidence, v)
207-
push!(tmp_coeffs, ci)
196+
numkvars = length(kvars)
197+
numivars = length(ivars)
198+
kvind = ivind = 0
199+
if _debug_mode
200+
# in debug mode, we at least check to confirm we're iterating over
201+
# `v`s in the correct order
202+
vars = sort(union(ivars, kvars))
203+
vi = 0
204+
end
205+
if numivars > 0 && numkvars > 0
206+
kvv = kvars[kvind += 1]
207+
ivv = ivars[ivind += 1]
208+
dobreak = false
209+
while true
210+
if kvv == ivv
211+
v = kvv
212+
ck = kcoeffs[kvind]
213+
ci = icoeffs[ivind]
214+
kvind += 1
215+
ivind += 1
216+
if kvind > numkvars
217+
dobreak = true
218+
else
219+
kvv = kvars[kvind]
220+
end
221+
if ivind > numivars
222+
dobreak = true
223+
else
224+
ivv = ivars[ivind]
225+
end
226+
elseif kvv < ivv
227+
v = kvv
228+
ck = kcoeffs[kvind]
229+
ci = zero(eltype(icoeffs))
230+
kvind += 1
231+
if kvind > numkvars
232+
dobreak = true
233+
else
234+
kvv = kvars[kvind]
235+
end
236+
else # kvv > ivv
237+
v = ivv
238+
ck = zero(eltype(kcoeffs))
239+
ci = icoeffs[ivind]
240+
ivind += 1
241+
if ivind > numivars
242+
dobreak = true
243+
else
244+
ivv = ivars[ivind]
245+
end
246+
end
247+
if _debug_mode
248+
@assert v == vars[vi += 1]
249+
end
250+
if v != vpivot
251+
p1 = Base.Checked.checked_mul(pivot, ci)
252+
p2 = Base.Checked.checked_mul(coeff, ck)
253+
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
254+
if !iszero(ci)
255+
push!(tmp_incidence, v)
256+
push!(tmp_coeffs, ci)
257+
end
258+
end
259+
dobreak && break
260+
end
261+
elseif numivars == 0
262+
ivind = 1
263+
kvv = kvars[kvind += 1]
264+
else # numkvars == 0
265+
kvind = 1
266+
ivv = ivars[ivind += 1]
267+
end
268+
if kvind <= numkvars
269+
v = kvv
270+
while true
271+
if _debug_mode
272+
@assert v == vars[vi += 1]
273+
end
274+
if v != vpivot
275+
ck = kcoeffs[kvind]
276+
p2 = Base.Checked.checked_mul(coeff, ck)
277+
ci = exactdiv(Base.Checked.checked_sub(0, p2), last_pivot)
278+
if !iszero(ci)
279+
push!(tmp_incidence, v)
280+
push!(tmp_coeffs, ci)
281+
end
282+
end
283+
(kvind == numkvars) && break
284+
v = kvars[kvind += 1]
285+
end
286+
elseif ivind <= numivars
287+
v = ivv
288+
while true
289+
if _debug_mode
290+
@assert v == vars[vi += 1]
291+
end
292+
if v != vpivot
293+
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
294+
ci = exactdiv(p1, last_pivot)
295+
if !iszero(ci)
296+
push!(tmp_incidence, v)
297+
push!(tmp_coeffs, ci)
298+
end
299+
end
300+
(ivind == numivars) && break
301+
v = ivars[ivind += 1]
208302
end
209303
end
210304
eadj[ei] = tmp_incidence

0 commit comments

Comments
 (0)