Skip to content

Commit a06fe2b

Browse files
committed
minor microoptimization
1 parent 04962e5 commit a06fe2b

File tree

2 files changed

+90
-23
lines changed

2 files changed

+90
-23
lines changed

src/systems/alias_elimination.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ the `constraint`.
153153
mask,
154154
constraint)
155155
eadj = M.row_cols
156-
for i in range
156+
@inbounds for i in range
157157
vertices = eadj[i]
158158
if constraint(length(vertices))
159159
for (j, v) in enumerate(vertices)
@@ -170,7 +170,7 @@ end
170170
range,
171171
mask,
172172
constraint)
173-
for i in range
173+
@inbounds for i in range
174174
row = @view M[i, :]
175175
if constraint(count(!iszero, row))
176176
for (v, val) in enumerate(row)

src/systems/sparsematrixclil.jl

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,74 @@ end
129129
# build something that works for us here and worry about it later.
130130
nonzerosmap(a::CLILVector) = NonZeros(a)
131131

132+
"""
    findfirstequal(vpivot, ivars)

Return the 1-based index of the first element of `ivars` that is `isequal` to
`vpivot`, or `nothing` if no element matches.

The generic method simply forwards to `findfirst`. A specialized method exists for
`Int64` needles in dense `Int64` vectors; it scans the underlying memory eight
elements at a time through hand-written LLVM IR.
"""
findfirstequal(vpivot, ivars) = findfirst(isequal(vpivot), ivars)

# The fast path hands `pointer(ivars)` and `length(ivars)` to IR that walks memory
# contiguously, so it is only valid for dense storage. It is therefore restricted to
# `DenseVector{Int64}`: strided views would otherwise be scanned with the wrong
# stride (wrong indices), and ranges do not support `pointer` at all. Every other
# `AbstractVector{Int64}` is served by the generic `findfirst` method above.
function findfirstequal(vpivot::Int64, ivars::DenseVector{Int64})
    # Keep `ivars` rooted while the raw pointer derived from it is in use.
    GC.@preserve ivars begin
        # IR contract: arguments are (needle, base address, length); the result is
        # the 0-based index of the first match, or -1 when there is none.
        #   top    : broadcast the needle into an 8-lane vector; skip the vector
        #            loop entirely when fewer than 8 elements are present.
        #   L9/L30 : compare 8 elements per iteration; advance by 8 on no match.
        #   L17    : on a match, count trailing zeros of the lane mask to get the
        #            position inside the 8-element chunk.
        #   L32    : start the scalar tail at `len & ~7` (or 0 if no vector loop ran).
        #   L51/L67: scalar element-by-element scan of the remaining tail.
        ret = Base.llvmcall(("""
            declare i8 @llvm.cttz.i8(i8, i1);
            define i64 @entry(i64 %0, i64 %1, i64 %2) #0 {
            top:
              %ivars = inttoptr i64 %1 to i64*
              %btmp = insertelement <8 x i64> undef, i64 %0, i64 0
              %var = shufflevector <8 x i64> %btmp, <8 x i64> undef, <8 x i32> zeroinitializer
              %lenm7 = add nsw i64 %2, -7
              %dosimditer = icmp ugt i64 %2, 7
              br i1 %dosimditer, label %L9.lr.ph, label %L32

            L9.lr.ph:
              %len8 = and i64 %2, 9223372036854775800
              br label %L9

            L9:
              %i = phi i64 [ 0, %L9.lr.ph ], [ %vinc, %L30 ]
              %ivarsi = getelementptr inbounds i64, i64* %ivars, i64 %i
              %vpvi = bitcast i64* %ivarsi to <8 x i64>*
              %v = load <8 x i64>, <8 x i64>* %vpvi, align 8
              %m = icmp eq <8 x i64> %v, %var
              %mu = bitcast <8 x i1> %m to i8
              %matchnotfound = icmp eq i8 %mu, 0
              br i1 %matchnotfound, label %L30, label %L17

            L17:
              %tz8 = call i8 @llvm.cttz.i8(i8 %mu, i1 true)
              %tz64 = zext i8 %tz8 to i64
              %vis = add nuw i64 %i, %tz64
              br label %common.ret

            common.ret:
              %retval = phi i64 [ %vis, %L17 ], [ -1, %L32 ], [ %si, %L51 ], [ -1, %L67 ]
              ret i64 %retval

            L30:
              %vinc = add nuw nsw i64 %i, 8
              %continue = icmp slt i64 %vinc, %lenm7
              br i1 %continue, label %L9, label %L32

            L32:
              %cumi = phi i64 [ 0, %top ], [ %len8, %L30 ]
              %done = icmp eq i64 %cumi, %2
              br i1 %done, label %common.ret, label %L51

            L51:
              %si = phi i64 [ %inc, %L67 ], [ %cumi, %L32 ]
              %spi = getelementptr inbounds i64, i64* %ivars, i64 %si
              %svi = load i64, i64* %spi, align 8
              %match = icmp eq i64 %svi, %0
              br i1 %match, label %common.ret, label %L67

            L67:
              %inc = add i64 %si, 1
              %dobreak = icmp eq i64 %inc, %2
              br i1 %dobreak, label %common.ret, label %L51

            }
            attributes #0 = { alwaysinline }
            """, "entry"), Int64, Tuple{Int64, Ptr{Int64}, Int64}, vpivot, pointer(ivars),
            length(ivars))
    end
    # Translate the IR's 0-based / -1 convention into Julia's 1-based / `nothing`.
    return ret < 0 ? nothing : ret + 1
end
199+
132200
function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot,
133201
last_pivot; pivot_equal_optimization = true)
134202
# for ei in nzrows(>= k)
@@ -168,12 +236,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
168236
# conservative, we leave it at this, as this captures the most important
169237
# case for MTK (where most pivots are `1` or `-1`).
170238
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)
171-
172239
@inbounds for ei in (k + 1):size(M, 1)
173240
# eliminate `v`
174241
coeff = 0
175242
ivars = eadj[ei]
176-
vj = findfirst(isequal(vpivot), ivars)
243+
vj = findfirstequal(vpivot, ivars)
177244
if vj !== nothing
178245
coeff = old_cadj[ei][vj]
179246
deleteat!(old_cadj[ei], vj)
@@ -189,12 +256,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
189256
ivars = eadj[ei]
190257
icoeffs = old_cadj[ei]
191258

192-
tmp_incidence = similar(eadj[ei], 0)
193-
tmp_coeffs = similar(old_cadj[ei], 0)
194-
# TODO: We know both ivars and kvars are sorted, we could just write
195-
# a quick iterator here that does this without allocation/faster.
196259
numkvars = length(kvars)
197260
numivars = length(ivars)
261+
tmp_incidence = similar(eadj[ei], numkvars + numivars)
262+
tmp_coeffs = similar(old_cadj[ei], numkvars + numivars)
263+
tmp_len = 0
198264
kvind = ivind = 0
199265
if _debug_mode
200266
# in debug mode, we at least check to confirm we're iterating over
@@ -223,38 +289,37 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
223289
else
224290
ivv = ivars[ivind]
225291
end
292+
p1 = Base.Checked.checked_mul(pivot, ci)
293+
p2 = Base.Checked.checked_mul(coeff, ck)
294+
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
226295
elseif kvv < ivv
227296
v = kvv
228297
ck = kcoeffs[kvind]
229-
ci = zero(eltype(icoeffs))
230298
kvind += 1
231299
if kvind > numkvars
232300
dobreak = true
233301
else
234302
kvv = kvars[kvind]
235303
end
304+
p2 = Base.Checked.checked_mul(coeff, ck)
305+
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
236306
else # kvv > ivv
237307
v = ivv
238-
ck = zero(eltype(kcoeffs))
239308
ci = icoeffs[ivind]
240309
ivind += 1
241310
if ivind > numivars
242311
dobreak = true
243312
else
244313
ivv = ivars[ivind]
245314
end
315+
ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot)
246316
end
247317
if _debug_mode
248318
@assert v == vars[vi += 1]
249319
end
250-
if v != vpivot
251-
p1 = Base.Checked.checked_mul(pivot, ci)
252-
p2 = Base.Checked.checked_mul(coeff, ck)
253-
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
254-
if !iszero(ci)
255-
push!(tmp_incidence, v)
256-
push!(tmp_coeffs, ci)
257-
end
320+
if v != vpivot && !iszero(ci)
321+
tmp_incidence[tmp_len += 1] = v
322+
tmp_coeffs[tmp_len] = ci
258323
end
259324
dobreak && break
260325
end
@@ -274,10 +339,10 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
274339
if v != vpivot
275340
ck = kcoeffs[kvind]
276341
p2 = Base.Checked.checked_mul(coeff, ck)
277-
ci = exactdiv(Base.Checked.checked_sub(0, p2), last_pivot)
342+
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
278343
if !iszero(ci)
279-
push!(tmp_incidence, v)
280-
push!(tmp_coeffs, ci)
344+
tmp_incidence[tmp_len += 1] = v
345+
tmp_coeffs[tmp_len] = ci
281346
end
282347
end
283348
(kvind == numkvars) && break
@@ -293,14 +358,16 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
293358
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
294359
ci = exactdiv(p1, last_pivot)
295360
if !iszero(ci)
296-
push!(tmp_incidence, v)
297-
push!(tmp_coeffs, ci)
361+
tmp_incidence[tmp_len += 1] = v
362+
tmp_coeffs[tmp_len] = ci
298363
end
299364
end
300365
(ivind == numivars) && break
301366
v = ivars[ivind += 1]
302367
end
303368
end
369+
resize!(tmp_incidence, tmp_len)
370+
resize!(tmp_coeffs, tmp_len)
304371
eadj[ei] = tmp_incidence
305372
old_cadj[ei] = tmp_coeffs
306373
end

0 commit comments

Comments
 (0)