@@ -267,24 +267,16 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
267
267
268
268
fill! (J, zero (eltype (J)))
269
269
270
- # fast path if J and sparsity are both SparseMatrixCSC and have the same number of columns and stored values
271
- sparseCSC_common = (J isa SparseMatrixCSC &&
272
- sparsity isa SparseMatrixCSC &&
273
- length (J. colptr) == length (sparsity. colptr) &&
274
- length (J. nzval) == length (sparsity. nzval))
275
-
276
- if sparseCSC_common
277
- J. colptr .= sparsity. colptr
278
- J. rowval .= sparsity. rowval
279
- end
280
-
281
270
if FiniteDiff. _use_findstructralnz (sparsity)
282
271
rows_index, cols_index = ArrayInterface. findstructralnz (sparsity)
283
272
else
284
273
rows_index = 1 : size (J,1 )
285
274
cols_index = 1 : size (J,2 )
286
275
end
287
276
277
+ # fast path if J and sparsity are both SparseMatrixCSC and have the same number of columns and stored values
278
+ sparseCSC_common_sparsity = FiniteDiff. _use_sparseCSC_common_sparsity! (J, sparsity)
279
+
288
280
vecx = vec (x)
289
281
vect = vec (t)
290
282
vecfx= vec (fx)
@@ -302,8 +294,8 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
302
294
303
295
if ArrayInterface. fast_scalar_indexing (dx)
304
296
# dx is implicitly used in vecdx
305
- if sparseCSC_common
306
- _colorediteration! (J,vecdx,colorvec,color_i,ncols)
297
+ if sparseCSC_common_sparsity
298
+ FiniteDiff . _colorediteration! (J,vecdx,colorvec,color_i,ncols)
307
299
else
308
300
FiniteDiff. _colorediteration! (J,sparsity,rows_index,cols_index,vecdx,colorvec,color_i,ncols)
309
301
end
@@ -334,16 +326,5 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
334
326
return J
335
327
end
336
328
337
- # fast version of FiniteDiff._colorediteration! for the case where J and sparsity have the same sparsity pattern
338
- @inline function _colorediteration! (Jsparsity:: SparseMatrixCSC ,vfx,colorvec,color_i,ncols)
339
- @inbounds for col_index in 1 : ncols
340
- if colorvec[col_index] == color_i
341
- @inbounds for spidx in nzrange (Jsparsity, col_index)
342
- row_index = Jsparsity. rowval[spidx]
343
- Jsparsity. nzval[spidx]= vfx[row_index]
344
- end
345
- end
346
- end
347
- end
348
329
349
330
0 commit comments