Skip to content

Commit ae26f23

Browse files
Merge pull request #146 from sjdaines/forwardiff_color_jacobian_fast_sparse_path
Add fast path for forwarddiff_color_jacobian! with sparse J
2 parents 8884f5c + b968680 commit ae26f23

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Adapt = "1, 2.0, 3.0"
2222
ArrayInterface = "2.8, 3.0"
2323
Compat = "2.2, 3"
2424
DataStructures = "0.17, 0.18"
25-
FiniteDiff = "2"
25+
FiniteDiff = "2.8.1"
2626
ForwardDiff = "0.10"
2727
LightGraphs = "1.3"
2828
Requires = "0.5, 1.0"

src/differentiation/compute_jacobian_ad.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
264264
chunksize = jac_cache.chunksize
265265
color_i = 1
266266
maxcolor = maximum(colorvec)
267+
267268
fill!(J, zero(eltype(J)))
268269

269270
if FiniteDiff._use_findstructralnz(sparsity)
@@ -273,6 +274,9 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
273274
cols_index = 1:size(J,2)
274275
end
275276

277+
# fast path if J and sparsity are both SparseMatrixCSC and have the same sparsity pattern
278+
sparseCSC_common_sparsity = FiniteDiff._use_sparseCSC_common_sparsity(J, sparsity)
279+
276280
vecx = vec(x)
277281
vect = vec(t)
278282
vecfx= vec(fx)
@@ -287,9 +291,14 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
287291
if !(sparsity isa Nothing)
288292
for j in 1:chunksize
289293
dx .= partials.(fx, j)
294+
290295
if ArrayInterface.fast_scalar_indexing(dx)
291296
#dx is implicitly used in vecdx
292-
FiniteDiff._colorediteration!(J,sparsity,rows_index,cols_index,vecdx,colorvec,color_i,ncols)
297+
if sparseCSC_common_sparsity
298+
FiniteDiff._colorediteration!(J,vecdx,colorvec,color_i,ncols)
299+
else
300+
FiniteDiff._colorediteration!(J,sparsity,rows_index,cols_index,vecdx,colorvec,color_i,ncols)
301+
end
293302
else
294303
#=
295304
J.nzval[rows_index] .+= (colorvec[cols_index] .== color_i) .* dx[rows_index]
@@ -316,3 +325,6 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
316325
end
317326
return J
318327
end
328+
329+
330+

0 commit comments

Comments
 (0)