1
- const AbstractFluxOptimiser = Union{Flux. Momentum,
1
+ const AbstractFluxOptimiser = Union{Flux. Momentum,
2
2
Flux. Nesterov,
3
3
Flux. RMSProp,
4
4
Flux. ADAM,
@@ -12,7 +12,6 @@ const AbstractFluxOptimiser = Union{Flux.Momentum,
12
12
Flux. AdaBelief,
13
13
Flux. Optimiser}
14
14
15
-
16
15
function __solve (prob:: OptimizationProblem , opt, data = DEFAULT_DATA;
17
16
maxiters:: Number = 0 , cb = (args... ) -> (false ),
18
17
progress = false , save_best = true , kwargs... )
@@ -29,7 +28,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
29
28
# Flux is silly and doesn't have an abstract type on its optimizers, so assume
30
29
# this is a Flux optimizer
31
30
θ = copy (prob. u0)
32
- ps = Flux . params (θ)
31
+ G = copy (θ)
33
32
34
33
t0 = time ()
35
34
@@ -41,10 +40,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
41
40
42
41
@withprogress progress name= " Training" begin
43
42
for (i,d) in enumerate (data)
44
- gs = Flux. Zygote. gradient (ps) do
45
- x = prob. f (θ,prob. p, d... )
46
- first (x)
47
- end
43
+ f. grad (G, θ, d... )
48
44
x = f. f (θ, prob. p, d... )
49
45
cb_call = cb (θ, x... )
50
46
if ! (typeof (cb_call) <: Bool )
@@ -54,7 +50,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
54
50
end
55
51
msg = @sprintf (" loss: %.3g" , x[1 ])
56
52
progress && ProgressLogging. @logprogress msg i/ maxiters
57
- Flux. update! (opt, ps, gs )
53
+ Flux. update! (opt, θ, G )
58
54
59
55
if save_best
60
56
if first (x) < first (min_err) # found a better solution
@@ -75,5 +71,22 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
75
71
# here should be build_solution to create the output message
76
72
end
77
73
74
# Plain (optimiser-less) descent update for dual-number gradients.
#
# ForwardDiff-based backends hand back the gradient as an array of
# `ForwardDiff.Dual`s; the actual derivative lives in the partials, so we
# extract the first partial of every entry before the in-place subtraction.
# (The original body subtracted the Dual array directly — that extraction
# had been swapped into the untyped method below.)
function Flux.update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
    x .-= getindex.(ForwardDiff.partials.(x̄), 1)
end
77
+
78
# Plain (optimiser-less) descent update for ordinary numeric gradients:
# subtract the raw gradient array from the parameters in place.
#
# The original body called `ForwardDiff.partials.(x̄)` here, which errors for
# any non-Dual gradient; that extraction belongs in the
# `AbstractArray{<:ForwardDiff.Dual}` method, which dispatch selects for
# Dual-valued gradients.
function Flux.update!(x::AbstractArray, x̄)
    x .-= x̄
end
81
+
82
# In-place parameter update through a Flux optimiser: let the optimiser
# transform the raw gradient (learning rate, momentum, etc.) into a step,
# then subtract that step from the parameters.
function Flux.update!(opt, x, x̄)
    step = Flux.Optimise.apply!(opt, x, x̄)
    x .-= step
end
78
85
86
# Optimiser update for dual-number gradients: strip the first partial out of
# each `ForwardDiff.Dual` entry to recover the plain gradient, then apply the
# optimiser's step rule in place.
function Flux.update!(opt, x, x̄::AbstractArray{<:ForwardDiff.Dual})
    grad = getindex.(ForwardDiff.partials.(x̄), 1)
    x .-= Flux.Optimise.apply!(opt, x, grad)
end
79
89
90
# Adapter so `Flux.update!` also accepts a Zygote `Params` collection:
# delegates to the array-valued method on the first parameter array.
# NOTE(review): only `xs[1]` is updated — this assumes the Params holds
# exactly one flat parameter array (θ, as built in `__solve` above); any
# additional entries would be silently ignored. Confirm against callers.
function Flux.update!(opt, xs::Flux.Zygote.Params, gs)
    update!(opt, xs[1], gs)
end
0 commit comments