Skip to content

Commit e32a713

Browse files
author
Stuart Daines
committed
Move fast sparse _colorediteration! to FiniteDiff
1 parent b1a3244 commit e32a713

File tree

1 file changed

+5
-24
lines changed

1 file changed

+5
-24
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -267,24 +267,16 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
267267

268268
fill!(J, zero(eltype(J)))
269269

270-
# fast path if J and sparsity are both SparseMatrixCSC and have the same number of columns and stored values
271-
sparseCSC_common = (J isa SparseMatrixCSC &&
272-
sparsity isa SparseMatrixCSC &&
273-
length(J.colptr) == length(sparsity.colptr) &&
274-
length(J.nzval) == length(sparsity.nzval))
275-
276-
if sparseCSC_common
277-
J.colptr .= sparsity.colptr
278-
J.rowval .= sparsity.rowval
279-
end
280-
281270
if FiniteDiff._use_findstructralnz(sparsity)
282271
rows_index, cols_index = ArrayInterface.findstructralnz(sparsity)
283272
else
284273
rows_index = 1:size(J,1)
285274
cols_index = 1:size(J,2)
286275
end
287276

277+
# fast path if J and sparsity are both SparseMatrixCSC and have the same number of columns and stored values
278+
sparseCSC_common_sparsity = FiniteDiff._use_sparseCSC_common_sparsity!(J, sparsity)
279+
288280
vecx = vec(x)
289281
vect = vec(t)
290282
vecfx= vec(fx)
@@ -302,8 +294,8 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
302294

303295
if ArrayInterface.fast_scalar_indexing(dx)
304296
#dx is implicitly used in vecdx
305-
if sparseCSC_common
306-
_colorediteration!(J,vecdx,colorvec,color_i,ncols)
297+
if sparseCSC_common_sparsity
298+
FiniteDiff._colorediteration!(J,vecdx,colorvec,color_i,ncols)
307299
else
308300
FiniteDiff._colorediteration!(J,sparsity,rows_index,cols_index,vecdx,colorvec,color_i,ncols)
309301
end
@@ -334,16 +326,5 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
334326
return J
335327
end
336328

337-
# fast version of FiniteDiff._colorediteration! for the case where J and sparsity have the same sparsity pattern
338-
@inline function _colorediteration!(Jsparsity::SparseMatrixCSC,vfx,colorvec,color_i,ncols)
339-
@inbounds for col_index in 1:ncols
340-
if colorvec[col_index] == color_i
341-
@inbounds for spidx in nzrange(Jsparsity, col_index)
342-
row_index = Jsparsity.rowval[spidx]
343-
Jsparsity.nzval[spidx]=vfx[row_index]
344-
end
345-
end
346-
end
347-
end
348329

349330

0 commit comments

Comments
 (0)