Skip to content

Commit 688b61f

Browse files
kaandocalChrisRackauckas
authored andcommitted
Reorganised dependencies
1 parent f393ae7 commit 688b61f

File tree

9 files changed

+312
-297
lines changed

9 files changed

+312
-297
lines changed

Project.toml

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,12 @@ version = "1.3.0"
66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
9-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
109
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1110
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1211
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
13-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1412
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1513
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1614
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
17-
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
18-
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1915
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2016
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
2117
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -29,15 +25,11 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2925
[compat]
3026
ArrayInterface = "2.13, 3.0"
3127
ConsoleProgressMonitor = "0.1"
32-
DiffEqBase = "6.48.1"
3328
DiffResults = "1.0"
3429
DocStringExtensions = "0.8"
3530
FiniteDiff = "2.5"
36-
Flux = "0.11, 0.12"
3731
ForwardDiff = "0.10"
3832
LoggingExtras = "0.4"
39-
ModelingToolkit = "5.2"
40-
Optim = "1"
4133
ProgressLogging = "0.1"
4234
Reexport = "0.2, 1.0"
4335
Requires = "1.0"
@@ -53,13 +45,16 @@ BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
5345
CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
5446
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
5547
Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
48+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5649
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
50+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
5751
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
52+
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
5853
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5954
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
6055
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6156
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
6257
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6358

6459
[targets]
65-
test = ["BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Pkg", "Random", "SafeTestsets", "Test"]
60+
test = ["Flux", "ModelingToolkit", "Optim", "BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Pkg", "Random", "SafeTestsets", "Test"]

src/GalacticOptim.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,15 @@ module GalacticOptim
55

66
using DocStringExtensions
77
using Reexport
8-
@reexport using DiffEqBase
98
@reexport using SciMLBase
109
using Requires
11-
using DiffResults, ForwardDiff, Zygote, ReverseDiff, FiniteDiff
12-
import Tracker
13-
@reexport using Optim, Flux
10+
using DiffResults, ForwardDiff, Zygote, ReverseDiff, Tracker, FiniteDiff
1411
using Logging, ProgressLogging, Printf, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
1512
using ArrayInterface, Base.Iterators
1613

1714
using ForwardDiff: DEFAULT_CHUNK_THRESHOLD
1815
import SciMLBase: OptimizationProblem, OptimizationFunction, AbstractADType, __solve
1916

20-
import ModelingToolkit
21-
import ModelingToolkit: AutoModelingToolkit
22-
export AutoModelingToolkit
23-
2417
include("solve.jl")
2518
include("function.jl")
2619

src/function.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ function instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)
3030
cons_j = f.cons_j === nothing ? nothing : (res,x)->f.cons_j(res,x,p)
3131
cons_h = f.cons_h === nothing ? nothing : (res,x)->f.cons_h(res,x,p)
3232

33-
OptimizationFunction{true,DiffEqBase.NoAD,typeof(f.f),typeof(grad),
33+
OptimizationFunction{true,SciMLBase.NoAD,typeof(f.f),typeof(grad),
3434
typeof(hess),typeof(hv),typeof(cons),
3535
typeof(cons_j),typeof(cons_h)}(f.f,
36-
DiffEqBase.NoAD(),grad,hess,hv,cons,
36+
SciMLBase.NoAD(),grad,hess,hv,cons,
3737
cons_j,cons_h)
3838
end
3939

@@ -138,7 +138,7 @@ function instantiate_function(f, x, ::AutoZygote, p, num_cons = 0)
138138
return OptimizationFunction{false,AutoZygote,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoZygote(),grad,hess,hv,nothing,nothing,nothing)
139139
end
140140

141-
function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParameters(), num_cons = 0)
141+
function instantiate_function(f, x, ::AutoReverseDiff, p=SciMLBase.NullParameters(), num_cons = 0)
142142
num_cons != 0 && error("AutoReverseDiff does not currently support constraints")
143143

144144
_f = (θ, args...) -> first(f.f(θ,p, args...))

src/solve.jl

Lines changed: 3 additions & 256 deletions
Original file line numberDiff line numberDiff line change
@@ -57,268 +57,15 @@ macro withprogress(progress, exprs...)
5757
$(exprs[end])
5858
end
5959
end |> esc
60-
end
61-
62-
function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
63-
maxiters::Number = 0, cb = (args...) -> (false),
64-
progress = false, save_best = true, kwargs...)
65-
66-
if data != DEFAULT_DATA
67-
maxiters = length(data)
68-
else
69-
if maxiters <= 0.0
70-
error("The number of maxiters has to be a non-negative and non-zero number.")
71-
end
72-
data = take(data, maxiters)
73-
end
74-
75-
# Flux is silly and doesn't have an abstract type on its optimizers, so assume
76-
# this is a Flux optimizer
77-
θ = copy(prob.u0)
78-
G = copy(θ)
79-
80-
t0 = time()
81-
82-
local x, min_err, _loss
83-
min_err = typemax(eltype(prob.u0)) #dummy variables
84-
min_opt = 1
85-
86-
f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
87-
88-
@withprogress progress name="Training" begin
89-
for (i,d) in enumerate(data)
90-
f.grad(G, θ, d...)
91-
x = f.f(θ, prob.p, d...)
92-
cb_call = cb(θ, x...)
93-
if !(typeof(cb_call) <: Bool)
94-
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
95-
elseif cb_call
96-
break
97-
end
98-
msg = @sprintf("loss: %.3g", x[1])
99-
progress && ProgressLogging.@logprogress msg i/maxiters
100-
Flux.update!(opt, θ, G)
101-
102-
if save_best
103-
if first(x) < first(min_err) #found a better solution
104-
min_opt = opt
105-
min_err = x
106-
end
107-
if i == maxiters #Last iteration, revert to best.
108-
opt = min_opt
109-
cb(θ,min_err...)
110-
end
111-
end
112-
end
113-
end
114-
115-
_time = time()
116-
117-
SciMLBase.build_solution(prob, opt, θ, x[1])
118-
# here should be build_solution to create the output message
11960
end
12061

121-
122-
decompose_trace(trace::Optim.OptimizationTrace) = last(trace)
12362
decompose_trace(trace) = trace
12463

125-
function __solve(prob::OptimizationProblem, opt::Optim.AbstractOptimizer,
126-
data = DEFAULT_DATA;
127-
maxiters = nothing,
128-
cb = (args...) -> (false),
129-
progress = false,
130-
kwargs...)
131-
local x, cur, state
132-
133-
if data != DEFAULT_DATA
134-
maxiters = length(data)
135-
end
136-
137-
cur, state = iterate(data)
138-
139-
function _cb(trace)
140-
cb_call = opt == NelderMead() ? cb(decompose_trace(trace).metadata["centroid"],x...) : cb(decompose_trace(trace).metadata["x"],x...)
141-
if !(typeof(cb_call) <: Bool)
142-
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
143-
end
144-
cur, state = iterate(data, state)
145-
cb_call
146-
end
147-
148-
if !(isnothing(maxiters)) && maxiters <= 0.0
149-
error("The number of maxiters has to be a non-negative and non-zero number.")
150-
elseif !(isnothing(maxiters))
151-
maxiters = convert(Int, maxiters)
152-
end
153-
154-
f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
155-
156-
!(opt isa Optim.ZerothOrderOptimizer) && f.grad === nothing && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
157-
158-
_loss = function(θ)
159-
x = f.f(θ, prob.p, cur...)
160-
return first(x)
161-
end
162-
163-
fg! = function (G,θ)
164-
if G !== nothing
165-
f.grad(G, θ, cur...)
166-
end
167-
return _loss(θ)
168-
end
169-
170-
if opt isa Optim.KrylovTrustRegion
171-
optim_f = Optim.TwiceDifferentiableHV(_loss, fg!, (H,θ,v) -> f.hv(H,θ,v,cur...), prob.u0)
172-
else
173-
optim_f = TwiceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, (H,θ) -> f.hess(H,θ,cur...), prob.u0)
174-
end
175-
176-
original = Optim.optimize(optim_f, prob.u0, opt,
177-
!(isnothing(maxiters)) ?
178-
Optim.Options(;extended_trace = true,
179-
callback = _cb,
180-
iterations = maxiters,
181-
kwargs...) :
182-
Optim.Options(;extended_trace = true,
183-
callback = _cb, kwargs...))
184-
SciMLBase.build_solution(prob, opt, original.minimizer,
185-
original.minimum; original=original)
186-
end
187-
188-
function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN},
189-
data = DEFAULT_DATA;
190-
maxiters = nothing,
191-
cb = (args...) -> (false),
192-
progress = false,
193-
kwargs...)
194-
195-
local x, cur, state
196-
197-
if data != DEFAULT_DATA
198-
maxiters = length(data)
199-
end
200-
201-
cur, state = iterate(data)
202-
203-
function _cb(trace)
204-
cb_call = !(opt isa Optim.SAMIN) && opt.method == NelderMead() ? cb(decompose_trace(trace).metadata["centroid"],x...) : cb(decompose_trace(trace).metadata["x"],x...)
205-
if !(typeof(cb_call) <: Bool)
206-
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
207-
end
208-
cur, state = iterate(data, state)
209-
cb_call
210-
end
211-
212-
if !(isnothing(maxiters)) && maxiters <= 0.0
213-
error("The number of maxiters has to be a non-negative and non-zero number.")
214-
elseif !(isnothing(maxiters))
215-
maxiters = convert(Int, maxiters)
216-
end
217-
218-
f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
219-
220-
!(opt isa Optim.ZerothOrderOptimizer) && f.grad === nothing && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
221-
222-
_loss = function(θ)
223-
x = f.f(θ, prob.p, cur...)
224-
return first(x)
225-
end
226-
fg! = function (G,θ)
227-
if G !== nothing
228-
f.grad(G, θ, cur...)
229-
end
230-
231-
return _loss(θ)
232-
end
233-
optim_f = OnceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, prob.u0)
234-
235-
original = Optim.optimize(optim_f, prob.lb, prob.ub, prob.u0, opt,
236-
!(isnothing(maxiters)) ? Optim.Options(;
237-
extended_trace = true, callback = _cb,
238-
iterations = maxiters, kwargs...) :
239-
Optim.Options(;extended_trace = true,
240-
callback = _cb, kwargs...))
241-
SciMLBase.build_solution(prob, opt, original.minimizer,
242-
original.minimum; original=original)
243-
end
244-
245-
246-
function __solve(prob::OptimizationProblem, opt::Optim.ConstrainedOptimizer,
247-
data = DEFAULT_DATA;
248-
maxiters = nothing,
249-
cb = (args...) -> (false),
250-
progress = false,
251-
kwargs...)
252-
253-
local x, cur, state
254-
255-
if data != DEFAULT_DATA
256-
maxiters = length(data)
257-
end
258-
259-
cur, state = iterate(data)
260-
261-
function _cb(trace)
262-
cb_call = cb(decompose_trace(trace).metadata["x"],x...)
263-
if !(typeof(cb_call) <: Bool)
264-
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
265-
end
266-
cur, state = iterate(data, state)
267-
cb_call
268-
end
269-
270-
if !(isnothing(maxiters)) && maxiters <= 0.0
271-
error("The number of maxiters has to be a non-negative and non-zero number.")
272-
elseif !(isnothing(maxiters))
273-
maxiters = convert(Int, maxiters)
274-
end
275-
276-
f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p,prob.ucons === nothing ? 0 : length(prob.ucons))
277-
278-
f.cons_j ===nothing && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
279-
280-
_loss = function(θ)
281-
x = f.f(θ, prob.p, cur...)
282-
return x[1]
283-
end
284-
fg! = function (G,θ)
285-
if G !== nothing
286-
f.grad(G, θ, cur...)
287-
end
288-
return _loss(θ)
289-
end
290-
optim_f = TwiceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, (H,θ) -> f.hess(H, θ, cur...), prob.u0)
291-
292-
cons! = (res, θ) -> res .= f.cons(θ);
293-
294-
cons_j! = function(J, x)
295-
f.cons_j(J, x)
296-
end
297-
298-
cons_hl! = function (h, θ, λ)
299-
res = [similar(h) for i in 1:length(λ)]
300-
f.cons_h(res, θ)
301-
for i in 1:length(λ)
302-
h .+= λ[i]*res[i]
303-
end
304-
end
305-
306-
lb = prob.lb === nothing ? [] : prob.lb
307-
ub = prob.ub === nothing ? [] : prob.ub
308-
optim_fc = TwiceDifferentiableConstraints(cons!, cons_j!, cons_hl!, lb, ub, prob.lcons, prob.ucons)
309-
310-
original = Optim.optimize(optim_f, optim_fc, prob.u0, opt,
311-
!(isnothing(maxiters)) ? Optim.Options(;
312-
extended_trace = true, callback = _cb,
313-
iterations = maxiters, kwargs...) :
314-
Optim.Options(;extended_trace = true,
315-
callback = _cb, kwargs...))
316-
SciMLBase.build_solution(prob, opt, original.minimizer,
317-
original.minimum; original=original)
318-
end
64+
function __init__()
65+
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("solve_flux.jl")
31966

67+
@require Optim="429524aa-4258-5aef-a3af-852621145aeb" include("solve_optim.jl")
32068

321-
function __init__()
32269
@require BlackBoxOptim="a134a8b2-14d6-55f6-9291-3336d3ab0209" begin
32370
decompose_trace(opt::BlackBoxOptim.OptRunController) = BlackBoxOptim.best_candidate(opt)
32471

0 commit comments

Comments
 (0)