Skip to content

Commit 23f4b02

Browse files
kaandocalChrisRackauckas
authored andcommitted
Reorganised AD dependencies
1 parent dd7af28 commit 23f4b02

File tree

12 files changed

+312
-251
lines changed

12 files changed

+312
-251
lines changed

Project.toml

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,24 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
99
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1010
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
11-
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
12-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1311
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1412
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
1513
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1614
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1715
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1816
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
19-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2017
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2118
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
22-
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
23-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2419

2520
[compat]
2621
ArrayInterface = "2.13, 3.0"
2722
ConsoleProgressMonitor = "0.1"
2823
DiffResults = "1.0"
2924
DocStringExtensions = "0.8"
30-
FiniteDiff = "2.5"
31-
ForwardDiff = "0.10"
3225
LoggingExtras = "0.4"
3326
ProgressLogging = "0.1"
3427
Reexport = "0.2, 1.0"
3528
Requires = "1.0"
36-
ReverseDiff = "1.4"
3729
SciMLBase = "1.8.1"
3830
TerminalLoggers = "0.1"
3931
Tracker = "0.2"
@@ -45,7 +37,9 @@ BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
4537
CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
4638
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4739
Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
40+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
4841
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
42+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4943
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
5044
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
5145
MultistartOptimization = "3933049c-43be-478e-a8bb-6e0f7fd53575"
@@ -54,8 +48,11 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
5448
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5549
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5650
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
51+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
5752
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5853
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
54+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
55+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5956

6057
[targets]
61-
test = ["Flux", "ModelingToolkit", "Optim", "BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Pkg", "Random", "SafeTestsets", "Test"]
58+
test = ["Flux", "ModelingToolkit", "Optim", "BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Pkg", "Random", "SafeTestsets", "Test", "FiniteDiff", "ForwardDiff", "Tracker", "ReverseDiff", "Zygote"]

src/GalacticOptim.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@ using DocStringExtensions
77
using Reexport
88
@reexport using SciMLBase
99
using Requires
10-
using DiffResults, ForwardDiff, Zygote, ReverseDiff, Tracker, FiniteDiff
10+
using DiffResults
1111
using Logging, ProgressLogging, Printf, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
1212
using ArrayInterface, Base.Iterators
1313

14-
using ForwardDiff: DEFAULT_CHUNK_THRESHOLD
1514
import SciMLBase: OptimizationProblem, OptimizationFunction, AbstractADType, __solve
1615

1716
include("solve/solve.jl")
18-
include("function.jl")
17+
include("function/function.jl")
1918

2019
function __init__()
2120
# Optimization backends
@@ -27,6 +26,13 @@ function __init__()
2726
@require NLopt="76087f3c-5699-56af-9a33-bf431cd00edd" include("solve/nlopt.jl")
2827
@require Optim="429524aa-4258-5aef-a3af-852621145aeb" include("solve/optim.jl")
2928
@require QuadDIRECT="dae52e8d-d666-5120-a592-9e15c33b8d7a" include("solve/quaddirect.jl")
29+
30+
# AD backends
31+
@require FiniteDiff="6a86dc24-6348-571c-b903-95158fe2bd41" include("function/finitediff.jl")
32+
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("function/forwarddiff.jl")
33+
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("function/reversediff.jl")
34+
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("function/tracker.jl")
35+
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("function/zygote.jl")
3036
end
3137

3238
export solve

src/function.jl

Lines changed: 0 additions & 237 deletions
This file was deleted.

src/function/finitediff.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
struct AutoFiniteDiff{T1,T2} <: AbstractADType
2+
fdtype::T1
3+
fdhtype::T2
4+
end
5+
6+
AutoFiniteDiff(;fdtype = Val(:forward), fdhtype = Val(:hcentral)) =
7+
AutoFiniteDiff(fdtype,fdhtype)
8+
9+
function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
10+
num_cons != 0 && error("AutoFiniteDiff does not currently support constraints")
11+
_f = (θ, args...) -> first(f.f(θ, p, args...))
12+
13+
if f.grad === nothing
14+
grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res, x ->_f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
15+
else
16+
grad = f.grad
17+
end
18+
19+
if f.hess === nothing
20+
hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res, x ->_f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
21+
else
22+
hess = f.hess
23+
end
24+
25+
if f.hv === nothing
26+
hv = function (H, θ, v, args...)
27+
res = ArrayInterface.zeromatrix(θ)
28+
hess(res, θ, args...)
29+
H .= res*v
30+
end
31+
else
32+
hv = f.hv
33+
end
34+
35+
return OptimizationFunction{false,AutoFiniteDiff,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,adtype,grad,hess,hv,nothing,nothing,nothing)
36+
end

0 commit comments

Comments
 (0)