Skip to content

Getcoeffchunk #2393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ the `constraint`.
mask,
constraint)
eadj = M.row_cols
for i in range
@inbounds for i in range
vertices = eadj[i]
if constraint(length(vertices))
for (j, v) in enumerate(vertices)
Expand All @@ -170,7 +170,7 @@ end
range,
mask,
constraint)
for i in range
@inbounds for i in range
row = @view M[i, :]
if constraint(count(!iszero, row))
for (v, val) in enumerate(row)
Expand Down Expand Up @@ -382,13 +382,6 @@ end

swap!(v, i, j) = v[i], v[j] = v[j], v[i]

function getcoeff(vars, coeffs, var)
for (vj, v) in enumerate(vars)
v == var && return coeffs[vj]
end
return 0
end

"""
$(SIGNATURES)

Expand Down
199 changes: 180 additions & 19 deletions src/systems/sparsematrixclil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,74 @@ end
# build something that works for us here and worry about it later.
nonzerosmap(a::CLILVector) = NonZeros(a)

findfirstequal(vpivot, ivars) = findfirst(isequal(vpivot), ivars)
function findfirstequal(vpivot::Int64, ivars::AbstractVector{Int64})
GC.@preserve ivars begin
ret = Base.llvmcall(("""
declare i8 @llvm.cttz.i8(i8, i1);
define i64 @entry(i64 %0, i64 %1, i64 %2) #0 {
top:
%ivars = inttoptr i64 %1 to i64*
%btmp = insertelement <8 x i64> undef, i64 %0, i64 0
%var = shufflevector <8 x i64> %btmp, <8 x i64> undef, <8 x i32> zeroinitializer
%lenm7 = add nsw i64 %2, -7
%dosimditer = icmp ugt i64 %2, 7
br i1 %dosimditer, label %L9.lr.ph, label %L32

L9.lr.ph:
%len8 = and i64 %2, 9223372036854775800
br label %L9

L9:
%i = phi i64 [ 0, %L9.lr.ph ], [ %vinc, %L30 ]
%ivarsi = getelementptr inbounds i64, i64* %ivars, i64 %i
%vpvi = bitcast i64* %ivarsi to <8 x i64>*
%v = load <8 x i64>, <8 x i64>* %vpvi, align 8
%m = icmp eq <8 x i64> %v, %var
%mu = bitcast <8 x i1> %m to i8
%matchnotfound = icmp eq i8 %mu, 0
br i1 %matchnotfound, label %L30, label %L17

L17:
%tz8 = call i8 @llvm.cttz.i8(i8 %mu, i1 true)
%tz64 = zext i8 %tz8 to i64
%vis = add nuw i64 %i, %tz64
br label %common.ret

common.ret:
%retval = phi i64 [ %vis, %L17 ], [ -1, %L32 ], [ %si, %L51 ], [ -1, %L67 ]
ret i64 %retval

L30:
%vinc = add nuw nsw i64 %i, 8
%continue = icmp slt i64 %vinc, %lenm7
br i1 %continue, label %L9, label %L32

L32:
%cumi = phi i64 [ 0, %top ], [ %len8, %L30 ]
%done = icmp eq i64 %cumi, %2
br i1 %done, label %common.ret, label %L51

L51:
%si = phi i64 [ %inc, %L67 ], [ %cumi, %L32 ]
%spi = getelementptr inbounds i64, i64* %ivars, i64 %si
%svi = load i64, i64* %spi, align 8
%match = icmp eq i64 %svi, %0
br i1 %match, label %common.ret, label %L67

L67:
%inc = add i64 %si, 1
%dobreak = icmp eq i64 %inc, %2
br i1 %dobreak, label %common.ret, label %L51

}
attributes #0 = { alwaysinline }
""", "entry"), Int64, Tuple{Int64, Ptr{Int64}, Int64}, vpivot, pointer(ivars),
length(ivars))
end
ret < 0 ? nothing : ret + 1
end

function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot,
last_pivot; pivot_equal_optimization = true)
# for ei in nzrows(>= k)
Expand Down Expand Up @@ -168,12 +236,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
# conservative, we leave it at this, as this captures the most important
# case for MTK (where most pivots are `1` or `-1`).
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)

for ei in (k + 1):size(M, 1)
@inbounds for ei in (k + 1):size(M, 1)
# eliminate `v`
coeff = 0
ivars = eadj[ei]
vj = findfirst(isequal(vpivot), ivars)
vj = findfirstequal(vpivot, ivars)
if vj !== nothing
coeff = old_cadj[ei][vj]
deleteat!(old_cadj[ei], vj)
Expand All @@ -189,24 +256,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
ivars = eadj[ei]
icoeffs = old_cadj[ei]

tmp_incidence = similar(eadj[ei], 0)
tmp_coeffs = similar(old_cadj[ei], 0)
# TODO: We know both ivars and kvars are sorted, we could just write
# a quick iterator here that does this without allocation/faster.
vars = sort(union(ivars, kvars))

for v in vars
v == vpivot && continue
ck = getcoeff(kvars, kcoeffs, v)
ci = getcoeff(ivars, icoeffs, v)
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
if !iszero(ci)
push!(tmp_incidence, v)
push!(tmp_coeffs, ci)
numkvars = length(kvars)
numivars = length(ivars)
tmp_incidence = similar(eadj[ei], numkvars + numivars)
tmp_coeffs = similar(old_cadj[ei], numkvars + numivars)
tmp_len = 0
kvind = ivind = 0
if _debug_mode
# in debug mode, we at least check to confirm we're iterating over
# `v`s in the correct order
vars = sort(union(ivars, kvars))
vi = 0
end
if numivars > 0 && numkvars > 0
kvv = kvars[kvind += 1]
ivv = ivars[ivind += 1]
dobreak = false
while true
if kvv == ivv
v = kvv
ck = kcoeffs[kvind]
ci = icoeffs[ivind]
kvind += 1
ivind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
elseif kvv < ivv
v = kvv
ck = kcoeffs[kvind]
kvind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
else # kvv > ivv
v = ivv
ci = icoeffs[ivind]
ivind += 1
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot)
end
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot && !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
dobreak && break
end
elseif numkvars > 0
ivind = 1
kvv = kvars[kvind += 1]
elseif numivars > 0
kvind = 1
ivv = ivars[ivind += 1]
end
if kvind <= numkvars
v = kvv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
ck = kcoeffs[kvind]
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
if !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
end
(kvind == numkvars) && break
v = kvars[kvind += 1]
end
elseif ivind <= numivars
v = ivv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
ci = exactdiv(p1, last_pivot)
if !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
end
(ivind == numivars) && break
v = ivars[ivind += 1]
end
end
resize!(tmp_incidence, tmp_len)
resize!(tmp_coeffs, tmp_len)
eadj[ei] = tmp_incidence
old_cadj[ei] = tmp_coeffs
end
Expand Down