Skip to content

Simplify _eval_hessian_inner #2730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 22, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 39 additions & 48 deletions src/Nonlinear/ReverseAD/forward_over_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,61 +50,19 @@ function _eval_hessian_inner(
@assert length(ex.hess_I) == 0
return 0
end
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
Coloring.prepare_seed_matrix!(ex.seed_matrix, ex.rinfo)
local_to_global_idx = ex.rinfo.local_indices
input_ϵ_raw, output_ϵ_raw = d.input_ϵ, d.output_ϵ
input_ϵ = _reinterpret_unsafe(T, input_ϵ_raw)
output_ϵ = _reinterpret_unsafe(T, output_ϵ_raw)
# Compute hessian-vector products
num_products = size(ex.seed_matrix, 2) # number of hessian-vector products
num_chunks = div(num_products, CHUNK)
@assert size(ex.seed_matrix, 1) == length(local_to_global_idx)
for k in 1:CHUNK:(CHUNK*num_chunks)
for r in 1:length(local_to_global_idx)
# set up directional derivatives
@inbounds idx = local_to_global_idx[r]
# load up ex.seed_matrix[r,k,k+1,...,k+CHUNK-1] into input_ϵ
for s in 1:CHUNK
input_ϵ_raw[(idx-1)*CHUNK+s] = ex.seed_matrix[r, k+s-1]
end
@inbounds output_ϵ[idx] = zero(T)
end
_hessian_slice_inner(d, ex, input_ϵ, output_ϵ, T)
# collect directional derivatives
for r in 1:length(local_to_global_idx)
idx = local_to_global_idx[r]
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+CHUNK-1]
for s in 1:CHUNK
ex.seed_matrix[r, k+s-1] = output_ϵ_raw[(idx-1)*CHUNK+s]
end
@inbounds input_ϵ[idx] = zero(T)
end
@assert size(ex.seed_matrix, 1) == length(ex.rinfo.local_indices)
for offset in 1:CHUNK:(CHUNK*num_chunks)
_eval_hessian_chunk(d, ex, offset, CHUNK, Val(CHUNK))
end
# leftover chunk
remaining = num_products - CHUNK * num_chunks
if remaining > 0
k = CHUNK * num_chunks + 1
for r in 1:length(local_to_global_idx)
# set up directional derivatives
@inbounds idx = local_to_global_idx[r]
# load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
for s in 1:remaining
# leave junk in the unused components
input_ϵ_raw[(idx-1)*CHUNK+s] = ex.seed_matrix[r, k+s-1]
end
@inbounds output_ϵ[idx] = zero(T)
end
_hessian_slice_inner(d, ex, input_ϵ, output_ϵ, T)
# collect directional derivatives
for r in 1:length(local_to_global_idx)
idx = local_to_global_idx[r]
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
for s in 1:remaining
ex.seed_matrix[r, k+s-1] = output_ϵ_raw[(idx-1)*CHUNK+s]
end
@inbounds input_ϵ[idx] = zero(T)
end
offset = CHUNK * num_chunks + 1
_eval_hessian_chunk(d, ex, offset, remaining, Val(CHUNK))
end
want, got = nzcount + length(ex.hess_I), length(H)
if want > got
Expand All @@ -127,7 +85,40 @@ function _eval_hessian_inner(
return length(ex.hess_I)
end

function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
function _eval_hessian_chunk(
d::NLPEvaluator,
ex::_FunctionStorage,
offset::Int,
chunk::Int,
::Val{CHUNK},
) where {CHUNK}
for r in eachindex(ex.rinfo.local_indices)
# set up directional derivatives
@inbounds idx = ex.rinfo.local_indices[r]
# load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
for s in 1:chunk
# If `chunk < CHUNK`, leaves junk in the unused components
d.input_ϵ[(idx-1)*CHUNK+s] = ex.seed_matrix[r, offset+s-1]
end
end
_hessian_slice_inner(d, ex, Val(CHUNK))
fill!(d.input_ϵ, 0.0)
# collect directional derivatives
for r in eachindex(ex.rinfo.local_indices)
@inbounds idx = ex.rinfo.local_indices[r]
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
for s in 1:chunk
ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*CHUNK+s]
end
end
return
end

function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
input_ϵ = _reinterpret_unsafe(T, d.input_ϵ)
fill!(d.output_ϵ, 0.0)
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
subexpr_forward_values_ϵ =
_reinterpret_unsafe(T, d.subexpression_forward_values_ϵ)
for i in ex.dependent_subexpressions
Expand Down
Loading