Skip to content

Commit 90f68f5

Browse files
use updated flux
1 parent 23f4b02 commit 90f68f5

File tree

2 files changed

+21
-30
lines changed

2 files changed

+21
-30
lines changed

src/solve/flux.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const AbstractFluxOptimiser = Union{Flux.Momentum,
1+
const AbstractFluxOptimiser = Union{Flux.Momentum,
22
Flux.Nesterov,
33
Flux.RMSProp,
44
Flux.ADAM,
@@ -12,7 +12,6 @@ const AbstractFluxOptimiser = Union{Flux.Momentum,
1212
Flux.AdaBelief,
1313
Flux.Optimiser}
1414

15-
1615
function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
1716
maxiters::Number = 0, cb = (args...) -> (false),
1817
progress = false, save_best = true, kwargs...)
@@ -29,7 +28,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
2928
# Flux is silly and doesn't have an abstract type on its optimizers, so assume
3029
# this is a Flux optimizer
3130
θ = copy(prob.u0)
32-
ps = Flux.params(θ)
31+
G = copy(θ)
3332

3433
t0 = time()
3534

@@ -41,10 +40,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
4140

4241
@withprogress progress name="Training" begin
4342
for (i,d) in enumerate(data)
44-
gs = Flux.Zygote.gradient(ps) do
45-
x = prob.f(θ,prob.p, d...)
46-
first(x)
47-
end
43+
f.grad(G, θ, d...)
4844
x = f.f(θ, prob.p, d...)
4945
cb_call = cb(θ, x...)
5046
if !(typeof(cb_call) <: Bool)
@@ -54,7 +50,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
5450
end
5551
msg = @sprintf("loss: %.3g", x[1])
5652
progress && ProgressLogging.@logprogress msg i/maxiters
57-
Flux.update!(opt, ps, gs)
53+
Flux.update!(opt, θ, G)
5854

5955
if save_best
6056
if first(x) < first(min_err) #found a better solution
@@ -75,5 +71,22 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
7571
# here should be build_solution to create the output message
7672
end
7773

74+
function Flux.update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
75+
x .-=
76+
end
77+
78+
function Flux.update!(x::AbstractArray, x̄)
79+
x .-= getindex.(ForwardDiff.partials.(x̄),1)
80+
end
81+
82+
function Flux.update!(opt, x, x̄)
83+
x .-= Flux.Optimise.apply!(opt, x, x̄)
84+
end
7885

86+
function Flux.update!(opt, x, x̄::AbstractArray{<:ForwardDiff.Dual})
87+
x .-= Flux.Optimise.apply!(opt, x, getindex.(ForwardDiff.partials.(x̄),1))
88+
end
7989

90+
function Flux.update!(opt, xs::Flux.Zygote.Params, gs)
91+
update!(opt, xs[1], gs)
92+
end

src/solve/solve.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,6 @@ get_maxiters(data) = Iterators.IteratorSize(typeof(DEFAULT_DATA)) isa Iterators.
77
Iterators.IteratorSize(typeof(DEFAULT_DATA)) isa Iterators.SizeUnknown ?
88
typemax(Int) : length(data)
99

10-
#=
11-
function update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
12-
x .-= x̄
13-
end
14-
15-
function update!(x::AbstractArray, x̄)
16-
x .-= getindex.(ForwardDiff.partials.(x̄),1)
17-
end
18-
19-
function update!(opt, x, x̄)
20-
x .-= Flux.Optimise.apply!(opt, x, x̄)
21-
end
22-
23-
function update!(opt, x, x̄::AbstractArray{<:ForwardDiff.Dual})
24-
x .-= Flux.Optimise.apply!(opt, x, getindex.(ForwardDiff.partials.(x̄),1))
25-
end
26-
27-
function update!(opt, xs::Flux.Zygote.Params, gs)
28-
update!(opt, xs[1], gs)
29-
end
30-
=#
31-
3210
maybe_with_logger(f, logger) = logger === nothing ? f() : Logging.with_logger(f, logger)
3311

3412
function default_logger(logger)

0 commit comments

Comments
 (0)