
Commit 7e115c9

Merge pull request #716 from sethaxen/optimjl_state_grad_hess
Store grad/hess in state for Optim.jl
2 parents 1723965 + 0a417ea commit 7e115c9

File tree: 2 files changed (+57, -13 lines)


lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 13 additions & 5 deletions
@@ -133,11 +133,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
         error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")

     function _cb(trace)
-        θ = cache.opt isa Optim.NelderMead ? decompose_trace(trace).metadata["centroid"] :
-            decompose_trace(trace).metadata["x"]
+        metadata = decompose_trace(trace).metadata
+        θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
         opt_state = Optimization.OptimizationState(iter = trace.iteration,
             u = θ,
             objective = x[1],
+            grad = get(metadata, "g(x)", nothing),
+            hess = get(metadata, "h(x)", nothing),
             original = trace)
         cb_call = cache.callback(opt_state, x...)
         if !(cb_call isa Bool)
@@ -252,12 +254,15 @@ function SciMLBase.__solve(cache::OptimizationCache{
     cur, state = iterate(cache.data)

     function _cb(trace)
+        metadata = decompose_trace(trace).metadata
         θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
-            decompose_trace(trace).metadata["centroid"] :
-            decompose_trace(trace).metadata["x"]
+            metadata["centroid"] :
+            metadata["x"]
         opt_state = Optimization.OptimizationState(iter = trace.iteration,
             u = θ,
             objective = x[1],
+            grad = get(metadata, "g(x)", nothing),
+            hess = get(metadata, "h(x)", nothing),
             original = trace)
         cb_call = cache.callback(opt_state, x...)
         if !(cb_call isa Bool)
@@ -341,8 +346,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
     cur, state = iterate(cache.data)

     function _cb(trace)
+        metadata = decompose_trace(trace).metadata
         opt_state = Optimization.OptimizationState(iter = trace.iteration,
-            u = decompose_trace(trace).metadata["x"],
+            u = metadata["x"],
+            grad = get(metadata, "g(x)", nothing),
+            hess = get(metadata, "h(x)", nothing),
             objective = x[1],
             original = trace)
         cb_call = cache.callback(opt_state, x...)
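
With this change, a callback passed to solve receives an Optimization.OptimizationState whose grad and hess fields are filled from the Optim.jl trace metadata (the "g(x)" and "h(x)" entries) when the chosen method records them, and are nothing otherwise. A minimal sketch of how a user callback might consume these fields; the (state, loss) signature and callback keyword follow the tests below, while the callback name and logging are illustrative:

using Optimization, OptimizationOptimJL

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
# AutoForwardDiff generates the derivatives Optim.jl needs to trace g(x)/h(x).
f = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(f, zeros(2), [1.0, 100.0])

function my_callback(state, loss)
    # grad/hess are `nothing` for methods that do not record them (e.g. NelderMead).
    if state.grad !== nothing
        @info "iteration $(state.iter)" loss gradnorm = sqrt(sum(abs2, state.grad))
    end
    return false  # returning true halts the optimization
end

sol = solve(prob, Optim.Newton(); callback = my_callback)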

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 44 additions & 8 deletions
@@ -3,6 +3,32 @@ using OptimizationOptimJL,
     Random, ModelingToolkit
 using Test

+struct CallbackTester
+    dim::Int
+    has_grad::Bool
+    has_hess::Bool
+end
+function CallbackTester(dim::Int; has_grad = false, has_hess = false)
+    CallbackTester(dim, has_grad, has_hess)
+end
+
+function (cb::CallbackTester)(state, loss_val)
+    @test length(state.u) == cb.dim
+    if cb.has_grad
+        @test state.grad isa AbstractVector
+        @test length(state.grad) == cb.dim
+    else
+        @test state.grad === nothing
+    end
+    if cb.has_hess
+        @test state.hess isa AbstractMatrix
+        @test size(state.hess) == (cb.dim, cb.dim)
+    else
+        @test state.hess === nothing
+    end
+    return false
+end
+
 @testset "OptimizationOptimJL.jl" begin
     rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
     x0 = zeros(2)
@@ -13,34 +39,43 @@ using Test
     sol = solve(prob,
         Optim.NelderMead(;
             initial_simplex = Optim.AffineSimplexer(; a = 0.025,
-                b = 0.5)))
+                b = 0.5)); callback = CallbackTester(length(x0)))
     @test 10 * sol.objective < l1

     f = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())

     Random.seed!(1234)
     prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
-    sol = solve(prob, SAMIN())
+    sol = solve(prob, SAMIN(); callback = CallbackTester(length(x0)))
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.IPNewton())
+    sol = solve(
+        prob, Optim.IPNewton();
+        callback = CallbackTester(length(x0); has_grad = true, has_hess = true)
+    )
     @test 10 * sol.objective < l1

     prob = OptimizationProblem(f, x0, _p)
     Random.seed!(1234)
-    sol = solve(prob, SimulatedAnnealing())
+    sol = solve(prob, SimulatedAnnealing(); callback = CallbackTester(length(x0)))
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.BFGS())
+    sol = solve(prob, Optim.BFGS(); callback = CallbackTester(length(x0); has_grad = true))
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.Newton())
+    sol = solve(
+        prob, Optim.Newton();
+        callback = CallbackTester(length(x0); has_grad = true, has_hess = true)
+    )
     @test 10 * sol.objective < l1

     sol = solve(prob, Optim.KrylovTrustRegion())
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.BFGS(), maxiters = 1)
+    sol = solve(
+        prob, Optim.BFGS();
+        maxiters = 1, callback = CallbackTester(length(x0); has_grad = true)
+    )
     @test sol.original.iterations == 1

     sol = solve(prob, Optim.BFGS(), maxiters = 1, local_maxiters = 2)
@@ -92,7 +127,8 @@ using Test
     optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())

     prob = OptimizationProblem(optprob, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
-    sol = solve(prob, Optim.Fminbox())
+    sol = solve(
+        prob, Optim.Fminbox(); callback = CallbackTester(length(x0); has_grad = true))
     @test 10 * sol.objective < l1

     Random.seed!(1234)
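
Because gradient-free methods leave state.grad === nothing, a callback should guard on it before using the value. Reusing f and prob from the sketch after the first file's diff, a hedged sketch of gradient-based early stopping; the helper name and tolerance are hypothetical, not part of this PR:

# Halt BFGS once the traced gradient norm drops below a (hypothetical) tolerance;
# per the callback convention above, returning true stops the optimization.
stop_on_small_grad(state, loss) = state.grad !== nothing && sum(abs2, state.grad) < 1e-16

sol = solve(prob, Optim.BFGS(); callback = stop_on_small_grad)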
