Skip to content

Commit cbb3e07

Browse files
Generalize Flux to allow use of other AD systems
1 parent 2993fb7 commit cbb3e07

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

src/solve.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
7575
# Flux is silly and doesn't have an abstract type on its optimizers, so assume
7676
# this is a Flux optimizer
7777
θ = copy(prob.u0)
78-
ps = Flux.params(θ)
78+
G = copy(θ)
7979

8080
t0 = time()
8181

@@ -87,10 +87,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
8787

8888
@withprogress progress name="Training" begin
8989
for (i,d) in enumerate(data)
90-
gs = Flux.Zygote.gradient(ps) do
91-
x = prob.f(θ,prob.p, d...)
92-
first(x)
93-
end
90+
f.grad(G, θ, d...)
9491
x = f.f(θ, prob.p, d...)
9592
cb_call = cb(θ, x...)
9693
if !(typeof(cb_call) <: Bool)
@@ -100,7 +97,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
10097
end
10198
msg = @sprintf("loss: %.3g", x[1])
10299
progress && ProgressLogging.@logprogress msg i/maxiters
103-
Flux.update!(opt, ps, gs)
100+
Flux.update!(opt, θ, G)
104101

105102
if save_best
106103
if first(x) < first(min_err) #found a better solution

0 commit comments

Comments
 (0)