Skip to content

Commit 5fc872e

Browse files
Merge pull request #336 from SebastianM-C/redirect_stdout
Add progress info for OptimizationBBO
2 parents fe73975 + 0f31a23 commit 5fc872e

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

lib/OptimizationBBO/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
88
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
99

1010
[compat]
11-
julia = "1"
1211
BlackBoxOptim = "0.6"
1312
Optimization = "3"
13+
julia = "1"
1414

1515
[extras]
1616
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

lib/OptimizationBBO/src/OptimizationBBO.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,33 @@ for j in string.(BlackBoxOptim.SingleObjectiveMethodNames)
99
eval(Meta.parse("export BBO_" * j))
1010
end
1111

12-
decompose_trace(opt::BlackBoxOptim.OptRunController) = BlackBoxOptim.best_candidate(opt)
12+
function decompose_trace(opt::BlackBoxOptim.OptRunController, progress)
13+
if progress
14+
maxiters = opt.max_steps
15+
max_time = opt.max_time
16+
msg = "loss: " * sprint(show, best_fitness(opt), context = :compact => true)
17+
if iszero(max_time)
18+
# we stop at either convergence or max_steps
19+
n_steps = BlackBoxOptim.num_steps(opt)
20+
Base.@logmsg(Base.LogLevel(-1), msg, progress=n_steps / maxiters,
21+
_id=:OptimizationBBO)
22+
else
23+
# we stop at either convergence or max_time
24+
elapsed = BlackBoxOptim.elapsed_time(opt)
25+
Base.@logmsg(Base.LogLevel(-1), msg, progress=elapsed / max_time,
26+
_id=:OptimizationBBO)
27+
end
28+
end
29+
return BlackBoxOptim.best_candidate(opt)
30+
end
1331

1432
function __map_optimizer_args(prob::SciMLBase.OptimizationProblem, opt::BBO;
1533
callback = nothing,
1634
maxiters::Union{Number, Nothing} = nothing,
1735
maxtime::Union{Number, Nothing} = nothing,
1836
abstol::Union{Number, Nothing} = nothing,
1937
reltol::Union{Number, Nothing} = nothing,
38+
verbose::Bool = false,
2039
kwargs...)
2140
if !isnothing(reltol)
2241
@warn "common reltol is currently not used by $(opt)"
@@ -44,6 +63,12 @@ function __map_optimizer_args(prob::SciMLBase.OptimizationProblem, opt::BBO;
4463
mapped_args = (; mapped_args..., MinDeltaFitnessTolerance = abstol)
4564
end
4665

66+
if verbose
67+
mapped_args = (; mapped_args..., TraceMode = :verbose)
68+
else
69+
mapped_args = (; mapped_args..., TraceMode = :silent)
70+
end
71+
4772
return mapped_args
4873
end
4974

@@ -54,6 +79,7 @@ function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, opt::BBO,
5479
maxtime::Union{Number, Nothing} = nothing,
5580
abstol::Union{Number, Nothing} = nothing,
5681
reltol::Union{Number, Nothing} = nothing,
82+
verbose::Bool = false,
5783
progress = false, kwargs...)
5884
local x, cur, state
5985

@@ -64,7 +90,7 @@ function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, opt::BBO,
6490
cur, state = iterate(data)
6591

6692
function _cb(trace)
67-
cb_call = callback(decompose_trace(trace), x...)
93+
cb_call = callback(decompose_trace(trace, progress), x...)
6894
if !(typeof(cb_call) <: Bool)
6995
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
7096
end
@@ -85,12 +111,19 @@ function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, opt::BBO,
85111

86112
opt_args = __map_optimizer_args(prob, opt, callback = _cb, maxiters = maxiters,
87113
maxtime = maxtime, abstol = abstol, reltol = reltol;
88-
kwargs...)
114+
verbose = verbose, kwargs...)
89115

90116
opt_setup = BlackBoxOptim.bbsetup(_loss; opt_args...)
91117

92118
t0 = time()
119+
93120
opt_res = BlackBoxOptim.bboptimize(opt_setup)
121+
122+
if progress
123+
# Set progressbar to 1 to finish it
124+
Base.@logmsg(Base.LogLevel(-1), "", progress=1, _id=:OptimizationBBO)
125+
end
126+
94127
t1 = time()
95128

96129
opt_ret = Symbol(opt_res.stop_reason)

lib/OptimizationBBO/test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,18 @@ using Test
1212
ub = [0.8, 0.8])
1313
sol = solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited())
1414
@test 10 * sol.minimum < l1
15+
16+
@test_logs begin
17+
(Base.LogLevel(-1), "loss: 0.0")
18+
min_level = Base.LogLevel(-1)
19+
solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited(), progress = true)
20+
end
21+
22+
@test_logs begin
23+
(Base.LogLevel(-1), "loss: 0.0")
24+
min_level = Base.LogLevel(-1)
25+
solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited(),
26+
progress = true,
27+
maxtime = 5)
28+
end
1529
end

0 commit comments

Comments
 (0)