Skip to content

Commit 6349b1b

Browse files
Add GalacticBBO subpackage
1 parent f3d9338 commit 6349b1b

File tree

7 files changed

+136
-103
lines changed

7 files changed

+136
-103
lines changed

GalacticBBO/Project.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
name = "GalacticBBO"
2+
uuid = "80c49c3a-6557-47d9-8f5b-13d0a2920315"
3+
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
8+
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
9+
10+
[compat]
11+
julia = "1"
12+
13+
[extras]
14+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
16+
[targets]
17+
test = ["Test"]

GalacticBBO/src/GalacticBBO.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
module GalacticBBO
2+
3+
using BlackBoxOptim, GalacticOptim, GalacticOptim.SciMLBase
4+
5+
abstract type BBO end
6+
7+
for j = string.(BlackBoxOptim.SingleObjectiveMethodNames)
8+
eval(Meta.parse("Base.@kwdef struct BBO_" * j * " <: BBO method=:" * j * " end"))
9+
eval(Meta.parse("export BBO_" * j))
10+
end
11+
12+
decompose_trace(opt::BlackBoxOptim.OptRunController) = BlackBoxOptim.best_candidate(opt)
13+
14+
function __map_optimizer_args(prob::SciMLBase.OptimizationProblem, opt::BBO;
15+
cb=nothing,
16+
maxiters::Union{Number,Nothing}=nothing,
17+
maxtime::Union{Number,Nothing}=nothing,
18+
abstol::Union{Number,Nothing}=nothing,
19+
reltol::Union{Number,Nothing}=nothing,
20+
kwargs...)
21+
22+
if !isnothing(reltol)
23+
@warn "common reltol is currently not used by $(opt)"
24+
end
25+
26+
mapped_args = (; Method=opt.method,
27+
SearchRange=[(prob.lb[i], prob.ub[i]) for i in 1:length(prob.lb)])
28+
29+
if !isnothing(cb)
30+
mapped_args = (; mapped_args..., CallbackFunction=cb, CallbackInterval=0.0)
31+
end
32+
33+
mapped_args = (; mapped_args..., kwargs...)
34+
35+
if !isnothing(maxiters)
36+
mapped_args = (; mapped_args..., MaxSteps=maxiters)
37+
end
38+
39+
if !isnothing(maxtime)
40+
mapped_args = (; mapped_args..., MaxTime=maxtime)
41+
end
42+
43+
if !isnothing(abstol)
44+
mapped_args = (; mapped_args..., MinDeltaFitnessTolerance=abstol)
45+
end
46+
47+
return mapped_args
48+
end
49+
50+
function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, opt::BBO, data=GalacticOptim.DEFAULT_DATA;
51+
cb=(args...) -> (false),
52+
maxiters::Union{Number,Nothing}=nothing,
53+
maxtime::Union{Number,Nothing}=nothing,
54+
abstol::Union{Number,Nothing}=nothing,
55+
reltol::Union{Number,Nothing}=nothing,
56+
progress=false, kwargs...)
57+
58+
local x, cur, state
59+
60+
if data != GalacticOptim.DEFAULT_DATA
61+
maxiters = length(data)
62+
end
63+
64+
cur, state = iterate(data)
65+
66+
function _cb(trace)
67+
cb_call = cb(decompose_trace(trace), x...)
68+
if !(typeof(cb_call) <: Bool)
69+
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
70+
end
71+
if cb_call == true
72+
BlackBoxOptim.shutdown_optimizer!(trace) #doesn't work
73+
end
74+
cur, state = iterate(data, state)
75+
cb_call
76+
end
77+
78+
maxiters = GalacticOptim._check_and_convert_maxiters(maxiters)
79+
maxtime = GalacticOptim._check_and_convert_maxtime(maxtime)
80+
81+
82+
_loss = function (θ)
83+
x = prob.f(θ, prob.p, cur...)
84+
return first(x)
85+
end
86+
87+
opt_args = __map_optimizer_args(prob, opt, cb=_cb, maxiters=maxiters, maxtime=maxtime, abstol=abstol, reltol=reltol; kwargs...)
88+
89+
opt_setup = BlackBoxOptim.bbsetup(_loss; opt_args...)
90+
91+
t0 = time()
92+
opt_res = BlackBoxOptim.bboptimize(opt_setup)
93+
t1 = time()
94+
95+
opt_ret = Symbol(opt_res.stop_reason)
96+
97+
SciMLBase.build_solution(prob, opt, BlackBoxOptim.best_candidate(opt_res),
98+
BlackBoxOptim.best_fitness(opt_res); original=opt_res, retcode=opt_ret)
99+
end
100+
101+
export __solve, __map_optimizer_args
102+
103+
end

GalacticBBO/test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using GalacticBBO, GalacticOptim, Zygote
2+
using Test
3+
4+
@testset "GalacticBBO.jl" begin
5+
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
6+
x0 = zeros(2)
7+
_p = [1.0, 100.0]
8+
l1 = rosenbrock(x0, _p)
9+
10+
optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoZygote())
11+
prob = GalacticOptim.OptimizationProblem(optprob, x0, _p, lb=[-1.0, -1.0], ub=[0.8, 0.8])
12+
sol = solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited())
13+
@test 10 * sol.minimum < l1
14+
end

src/GalacticOptim.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ include("function/function.jl")
2121

2222
function __init__()
2323
# Optimization backends
24-
@require BlackBoxOptim="a134a8b2-14d6-55f6-9291-3336d3ab0209" include("solve/blackboxoptim.jl")
2524
@require CMAEvolutionStrategy="8d3b24bd-414e-49e0-94fb-163cc3a3e411" include("solve/cmaevolutionstrategy.jl")
2625
@require Evolutionary="86b6b26d-c046-49b6-aa0b-5f0f74682bd6" include("solve/evolutionary.jl")
2726
@require GCMAES="4aa9d100-eb0f-11e8-15f1-25748831eb3b" include("solve/gcmaes.jl")

src/solve/blackboxoptim.jl

Lines changed: 0 additions & 95 deletions
This file was deleted.

src/solve/flux.jl

Whitespace-only changes.

src/utils.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@ end
3939

4040
decompose_trace(trace) = trace
4141

42-
43-
function _map_optimizer_args(prob::OptimizationProblem, opt; kwargs...)
44-
__map_optimizer_args(prob, opt; kwargs...)
45-
end
46-
4742
function _check_and_convert_maxiters(maxiters)
4843
if !(isnothing(maxiters)) && maxiters <= 0.0
4944
error("The number of maxiters has to be a non-negative and non-zero number.")
@@ -69,5 +64,5 @@ function check_pkg_version(pkg::String,ver::String; branch::Union{String, Nothin
6964
pkg_info[dep.name] = dep
7065
end
7166

72-
return (isnothing(branch) | (pkg_info[pkg].git_revision == branch)) ? pkg_info[pkg].version >= VersionNumber(ver) : pkg_info[pkg].version > VersionNumber(ver)
73-
end
67+
return (isnothing(branch) | (pkg_info[pkg].git_revision == branch)) ? pkg_info[pkg].version >= VersionNumber(ver) : pkg_info[pkg].version > VersionNumber(ver)
68+
end

0 commit comments

Comments
 (0)