Skip to content

Commit d080318

Browse files
Use lbfgsb as the default solver
1 parent efe6038 commit d080318

File tree

4 files changed

+100
-0
lines changed

4 files changed

+100
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
10+
LBFGSB = "5be7bae1-8223-5378-bac3-9e7378a2f6e6"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1213
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"

src/Optimization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export ObjSense, MaxSense, MinSense
2323

2424
include("utils.jl")
2525
include("state.jl")
26+
include("lbfgsb.jl")
2627

2728
export solve
2829

src/lbfgsb.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using Optimization.SciMLBase, LBFGSB
2+
3+
@kwdef struct LBFGS
4+
m::Int=10
5+
end
6+
7+
SciMLBase.supports_opt_cache_interface(::LBFGS) = true
8+
SciMLBase.allowsbounds(::LBFGS) = true
9+
# SciMLBase.requiresgradient(::LBFGS) = true
10+
11+
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem,
12+
opt::LBFGS,
13+
data = Optimization.DEFAULT_DATA; save_best = true,
14+
callback = (args...) -> (false),
15+
progress = false, kwargs...)
16+
return OptimizationCache(prob, opt, data; save_best, callback, progress,
17+
kwargs...)
18+
end
19+
20+
function SciMLBase.__solve(cache::OptimizationCache{
21+
F,
22+
RC,
23+
LB,
24+
UB,
25+
LC,
26+
UC,
27+
S,
28+
O,
29+
D,
30+
P,
31+
C
32+
}) where {
33+
F,
34+
RC,
35+
LB,
36+
UB,
37+
LC,
38+
UC,
39+
S,
40+
O <:
41+
LBFGS,
42+
D,
43+
P,
44+
C
45+
}
46+
if cache.data != Optimization.DEFAULT_DATA
47+
maxiters = length(cache.data)
48+
data = cache.data
49+
else
50+
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
51+
data = Optimization.take(cache.data, maxiters)
52+
end
53+
54+
local x
55+
56+
_loss = function (θ)
57+
x = cache.f(θ, cache.p)
58+
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
59+
if cache.callback(opt_state, x...)
60+
error("Optimization halted by callback.")
61+
end
62+
return x[1]
63+
end
64+
65+
t0 = time()
66+
if cache.lb !== nothing && cache.ub !== nothing
67+
res = lbfgsb(_loss, cache.f.grad, cache.u0; m = cache.opt.m, maxiter = maxiters,
68+
lb = cache.lb, ub = cache.ub)
69+
else
70+
res = lbfgsb(_loss, cache.f.grad, cache.u0; m = cache.opt.m, maxiter = maxiters)
71+
end
72+
73+
t1 = time()
74+
stats = Optimization.OptimizationStats(; iterations = maxiters,
75+
time = t1 - t0, fevals = maxiters, gevals = maxiters)
76+
77+
return SciMLBase.build_solution(cache, cache.opt, res[2], res[1], stats = stats)
78+
end
79+

test/lbfgsb.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using Optimization
2+
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
3+
using ModelingToolkit, Enzyme, Random
4+
5+
x0 = zeros(2)
6+
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
7+
l1 = rosenbrock(x0)
8+
9+
optf = OptimizationFunction(rosenbrock, AutoForwardDiff())
10+
prob = OptimizationProblem(optf, x0)
11+
res = solve(prob, Optimization.LBFGS(), maxiters = 100)
12+
13+
@test res.u [1.0, 1.0] atol=1e-3
14+
15+
optf = OptimizationFunction(rosenbrock, AutoZygote())
16+
prob = OptimizationProblem(optf, x0, lb = [0.0, 0.0], ub = [0.3, 0.3])
17+
res = solve(prob, Optimization.LBFGS(), maxiters = 100)
18+
19+
@test res.u [0.3, 0.09] atol=1e-3

0 commit comments

Comments
 (0)