
Commit 916bec4

Reorganised dependencies
1 parent 65af75c commit 916bec4

9 files changed: 311 additions, 298 deletions

Project.toml

Lines changed: 4 additions & 9 deletions

@@ -6,16 +6,12 @@ version = "1.2.0"
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
-DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
-ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
-Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -29,15 +25,11 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 ArrayInterface = "2.13, 3.0"
 ConsoleProgressMonitor = "0.1"
-DiffEqBase = "6.48.1"
 DiffResults = "1.0"
 DocStringExtensions = "0.8"
 FiniteDiff = "2.5"
-Flux = "0.11, 0.12"
 ForwardDiff = "0.10"
 LoggingExtras = "0.4"
-ModelingToolkit = "5.2"
-Optim = "1"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 Requires = "1.0"
@@ -53,13 +45,16 @@ BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
 CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
 DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
 Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
+ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Pkg", "Random", "SafeTestsets", "Test"]
+test = ["Flux", "ModelingToolkit", "Optim", "BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Pkg", "Random", "SafeTestsets", "Test"]
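
With Flux, ModelingToolkit, and Optim moved out of [deps] into [extras] and the test target, the test suite still installs them, but installing GalacticOptim alone no longer does. A minimal sketch of the resulting opt-in workflow (package manager commands only; assumes a release of GalacticOptim carrying this Project.toml):

    using Pkg
    Pkg.add("GalacticOptim")   # no longer pulls in Flux, ModelingToolkit, or Optim
    Pkg.add("Optim")           # opt in explicitly to enable the Optim-backed solvers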

src/GalacticOptim.jl

Lines changed: 0 additions & 6 deletions

@@ -5,21 +5,15 @@ module GalacticOptim
 
 using DocStringExtensions
 using Reexport
-@reexport using DiffEqBase
 @reexport using SciMLBase
 using Requires
 using DiffResults, ForwardDiff, Zygote, ReverseDiff, Tracker, FiniteDiff
-@reexport using Optim, Flux
 using Logging, ProgressLogging, Printf, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
 using ArrayInterface, Base.Iterators
 
 using ForwardDiff: DEFAULT_CHUNK_THRESHOLD
 import SciMLBase: OptimizationProblem, OptimizationFunction, AbstractADType, __solve
 
-import ModelingToolkit
-import ModelingToolkit: AutoModelingToolkit
-export AutoModelingToolkit
-
 include("solve.jl")
 include("function.jl")
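
Since `@reexport using Optim, Flux` is gone, downstream code now loads the optimizer package itself, which is also what fires the `@require` hooks added to src/solve.jl below. A minimal sketch of the new loading pattern (the Rosenbrock problem here is illustrative, not part of this commit):

    using GalacticOptim, Optim   # loading Optim activates the Requires hook

    rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
    optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff())
    prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
    sol = solve(prob, BFGS())    # dispatches to the __solve loaded from solve_optim.jl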

src/function.jl

Lines changed: 3 additions & 3 deletions

@@ -30,10 +30,10 @@ function instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)
     cons_j = f.cons_j === nothing ? nothing : (res,x)->f.cons_j(res,x,p)
     cons_h = f.cons_h === nothing ? nothing : (res,x)->f.cons_h(res,x,p)
 
-    OptimizationFunction{true,DiffEqBase.NoAD,typeof(f.f),typeof(grad),
+    OptimizationFunction{true,SciMLBase.NoAD,typeof(f.f),typeof(grad),
                          typeof(hess),typeof(hv),typeof(cons),
                          typeof(cons_j),typeof(cons_h)}(f.f,
-                         DiffEqBase.NoAD(),grad,hess,hv,cons,
+                         SciMLBase.NoAD(),grad,hess,hv,cons,
                          cons_j,cons_h)
 end
 
@@ -138,7 +138,7 @@ function instantiate_function(f, x, ::AutoZygote, p, num_cons = 0)
     return OptimizationFunction{false,AutoZygote,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoZygote(),grad,hess,hv,nothing,nothing,nothing)
 end
 
-function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParameters(), num_cons = 0)
+function instantiate_function(f, x, ::AutoReverseDiff, p=SciMLBase.NullParameters(), num_cons = 0)
     num_cons != 0 && error("AutoReverseDiff does not currently support constraints")
 
     _f = (θ, args...) -> first(f.f(θ,p, args...))

src/solve.jl

Lines changed: 3 additions & 259 deletions

@@ -57,271 +57,15 @@ macro withprogress(progress, exprs...)
             $(exprs[end])
         end
     end |> esc
-end
-
-function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
-                 maxiters::Number = 0, cb = (args...) -> (false),
-                 progress = false, save_best = true, kwargs...)
-
-    if data != DEFAULT_DATA
-        maxiters = length(data)
-    else
-        if maxiters <= 0.0
-            error("The number of maxiters has to be a non-negative and non-zero number.")
-        end
-        data = take(data, maxiters)
-    end
-
-    # Flux is silly and doesn't have an abstract type on its optimizers, so assume
-    # this is a Flux optimizer
-    θ = copy(prob.u0)
-    ps = Flux.params(θ)
-
-    t0 = time()
-
-    local x, min_err, _loss
-    min_err = typemax(eltype(prob.u0)) #dummy variables
-    min_opt = 1
-
-    f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
-
-    @withprogress progress name="Training" begin
-        for (i,d) in enumerate(data)
-            gs = Flux.Zygote.gradient(ps) do
-                x = prob.f(θ,prob.p, d...)
-                first(x)
-            end
-            x = f.f(θ, prob.p, d...)
-            cb_call = cb(θ, x...)
-            if !(typeof(cb_call) <: Bool)
-                error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
-            elseif cb_call
-                break
-            end
-            msg = @sprintf("loss: %.3g", x[1])
-            progress && ProgressLogging.@logprogress msg i/maxiters
-            Flux.update!(opt, ps, gs)
-
-            if save_best
-                if first(x) < first(min_err) #found a better solution
-                    min_opt = opt
-                    min_err = x
-                end
-                if i == maxiters #Last iteration, revert to best.
-                    opt = min_opt
-                    cb(θ,min_err...)
-                end
-            end
-        end
-    end
-
-    _time = time()
-
-    SciMLBase.build_solution(prob, opt, θ, x[1])
-    # here should be build_solution to create the output message
 end
 
-
-decompose_trace(trace::Optim.OptimizationTrace) = last(trace)
 decompose_trace(trace) = trace
 
-function __solve(prob::OptimizationProblem, opt::Optim.AbstractOptimizer,
-                 data = DEFAULT_DATA;
-                 maxiters = nothing,
-                 cb = (args...) -> (false),
-                 progress = false,
-                 kwargs...)
-    local x, cur, state
-
-    if data != DEFAULT_DATA
-        maxiters = length(data)
-    end
-
-    cur, state = iterate(data)
-
-    function _cb(trace)
-        cb_call = opt == NelderMead() ? cb(decompose_trace(trace).metadata["centroid"],x...) : cb(decompose_trace(trace).metadata["x"],x...)
-        if !(typeof(cb_call) <: Bool)
-            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
-        end
-        cur, state = iterate(data, state)
-        cb_call
-    end
-
-    if !(isnothing(maxiters)) && maxiters <= 0.0
-        error("The number of maxiters has to be a non-negative and non-zero number.")
-    elseif !(isnothing(maxiters))
-        maxiters = convert(Int, maxiters)
-    end
-
-    f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
-
-    !(opt isa Optim.ZerothOrderOptimizer) && f.grad === nothing && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
-
-    _loss = function(θ)
-        x = f.f(θ, prob.p, cur...)
-        return first(x)
-    end
-
-    fg! = function (G,θ)
-        if G !== nothing
-            f.grad(G, θ, cur...)
-        end
-        return _loss(θ)
-    end
-
-    if opt isa Optim.KrylovTrustRegion
-        optim_f = Optim.TwiceDifferentiableHV(_loss, fg!, (H,θ,v) -> f.hv(H,θ,v,cur...), prob.u0)
-    else
-        optim_f = TwiceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, (H,θ) -> f.hess(H,θ,cur...), prob.u0)
-    end
-
-    original = Optim.optimize(optim_f, prob.u0, opt,
-                              !(isnothing(maxiters)) ?
-                              Optim.Options(;extended_trace = true,
-                                            callback = _cb,
-                                            iterations = maxiters,
-                                            kwargs...) :
-                              Optim.Options(;extended_trace = true,
-                                            callback = _cb, kwargs...))
-    SciMLBase.build_solution(prob, opt, original.minimizer,
-                             original.minimum; original=original)
-end
-
-function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN},
-                 data = DEFAULT_DATA;
-                 maxiters = nothing,
-                 cb = (args...) -> (false),
-                 progress = false,
-                 kwargs...)
-
-    local x, cur, state
-
-    if data != DEFAULT_DATA
-        maxiters = length(data)
-    end
-
-    cur, state = iterate(data)
-
-    function _cb(trace)
-        cb_call = !(opt isa Optim.SAMIN) && opt.method == NelderMead() ? cb(decompose_trace(trace).metadata["centroid"],x...) : cb(decompose_trace(trace).metadata["x"],x...)
-        if !(typeof(cb_call) <: Bool)
-            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
-        end
-        cur, state = iterate(data, state)
-        cb_call
-    end
-
-    if !(isnothing(maxiters)) && maxiters <= 0.0
-        error("The number of maxiters has to be a non-negative and non-zero number.")
-    elseif !(isnothing(maxiters))
-        maxiters = convert(Int, maxiters)
-    end
-
-    f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
-
-    !(opt isa Optim.ZerothOrderOptimizer) && f.grad === nothing && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
-
-    _loss = function(θ)
-        x = f.f(θ, prob.p, cur...)
-        return first(x)
-    end
-    fg! = function (G,θ)
-        if G !== nothing
-            f.grad(G, θ, cur...)
-        end
-
-        return _loss(θ)
-    end
-    optim_f = OnceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, prob.u0)
-
-    original = Optim.optimize(optim_f, prob.lb, prob.ub, prob.u0, opt,
-                              !(isnothing(maxiters)) ? Optim.Options(;
-                              extended_trace = true, callback = _cb,
-                              iterations = maxiters, kwargs...) :
-                              Optim.Options(;extended_trace = true,
-                                            callback = _cb, kwargs...))
-    SciMLBase.build_solution(prob, opt, original.minimizer,
-                             original.minimum; original=original)
-end
-
-
-function __solve(prob::OptimizationProblem, opt::Optim.ConstrainedOptimizer,
-                 data = DEFAULT_DATA;
-                 maxiters = nothing,
-                 cb = (args...) -> (false),
-                 progress = false,
-                 kwargs...)
-
-    local x, cur, state
-
-    if data != DEFAULT_DATA
-        maxiters = length(data)
-    end
-
-    cur, state = iterate(data)
-
-    function _cb(trace)
-        cb_call = cb(decompose_trace(trace).metadata["x"],x...)
-        if !(typeof(cb_call) <: Bool)
-            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
-        end
-        cur, state = iterate(data, state)
-        cb_call
-    end
-
-    if !(isnothing(maxiters)) && maxiters <= 0.0
-        error("The number of maxiters has to be a non-negative and non-zero number.")
-    elseif !(isnothing(maxiters))
-        maxiters = convert(Int, maxiters)
-    end
-
-    f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p,prob.ucons === nothing ? 0 : length(prob.ucons))
-
-    f.cons_j === nothing && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
-
-    _loss = function(θ)
-        x = f.f(θ, prob.p, cur...)
-        return x[1]
-    end
-    fg! = function (G,θ)
-        if G !== nothing
-            f.grad(G, θ, cur...)
-        end
-        return _loss(θ)
-    end
-    optim_f = TwiceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, (H,θ) -> f.hess(H, θ, cur...), prob.u0)
-
-    cons! = (res, θ) -> res .= f.cons(θ)
-
-    cons_j! = function(J, x)
-        f.cons_j(J, x)
-    end
-
-    cons_hl! = function (h, θ, λ)
-        res = [similar(h) for i in 1:length(λ)]
-        f.cons_h(res, θ)
-        for i in 1:length(λ)
-            h .+= λ[i]*res[i]
-        end
-    end
-
-    lb = prob.lb === nothing ? [] : prob.lb
-    ub = prob.ub === nothing ? [] : prob.ub
-    optim_fc = TwiceDifferentiableConstraints(cons!, cons_j!, cons_hl!, lb, ub, prob.lcons, prob.ucons)
-
-    original = Optim.optimize(optim_f, optim_fc, prob.u0, opt,
-                              !(isnothing(maxiters)) ? Optim.Options(;
-                              extended_trace = true, callback = _cb,
-                              iterations = maxiters, kwargs...) :
-                              Optim.Options(;extended_trace = true,
-                                            callback = _cb, kwargs...))
-    SciMLBase.build_solution(prob, opt, original.minimizer,
-                             original.minimum; original=original)
-end
+function __init__()
+    @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("solve_flux.jl")
 
+    @require Optim="429524aa-4258-5aef-a3af-852621145aeb" include("solve_optim.jl")
 
-function __init__()
     @require BlackBoxOptim="a134a8b2-14d6-55f6-9291-3336d3ab0209" begin
         decompose_trace(opt::BlackBoxOptim.OptRunController) = BlackBoxOptim.best_candidate(opt)
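
The Flux path works the same way: the training-loop `__solve` moved to solve_flux.jl and is only defined once Flux is loaded in the session. A minimal usage sketch (hypothetical problem setup, mirroring the `maxiters` handling of the moved code):

    using GalacticOptim, Flux    # loading Flux activates the Requires hook

    loss(x, p) = sum(abs2, x .- p)
    optf = OptimizationFunction(loss, GalacticOptim.AutoZygote())
    prob = OptimizationProblem(optf, zeros(2), [1.0, 2.0])
    sol = solve(prob, ADAM(0.1), maxiters = 100)   # Flux optimizer, iterative loop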