You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
f =instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
87
-
88
-
@withprogress progress name="Training"begin
89
-
for (i,d) inenumerate(data)
90
-
gs = Flux.Zygote.gradient(ps) do
91
-
x = prob.f(θ,prob.p, d...)
92
-
first(x)
93
-
end
94
-
x = f.f(θ, prob.p, d...)
95
-
cb_call =cb(θ, x...)
96
-
if!(typeof(cb_call) <:Bool)
97
-
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
146
-
end
147
-
cur, state =iterate(data, state)
148
-
cb_call
149
-
end
150
-
151
-
if!(isnothing(maxiters)) && maxiters <=0.0
152
-
error("The number of maxiters has to be a non-negative and non-zero number.")
153
-
elseif!(isnothing(maxiters))
154
-
maxiters =convert(Int, maxiters)
155
-
end
156
-
157
-
f =instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
158
-
159
-
!(opt isa Optim.ZerothOrderOptimizer) && f.grad ===nothing&&error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
cb_call =!(opt isa Optim.SAMIN) && opt.method ==NelderMead() ?cb(decompose_trace(trace).metadata["centroid"],x...) :cb(decompose_trace(trace).metadata["x"],x...)
208
-
if!(typeof(cb_call) <:Bool)
209
-
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
210
-
end
211
-
cur, state =iterate(data, state)
212
-
cb_call
213
-
end
214
-
215
-
if!(isnothing(maxiters)) && maxiters <=0.0
216
-
error("The number of maxiters has to be a non-negative and non-zero number.")
217
-
elseif!(isnothing(maxiters))
218
-
maxiters =convert(Int, maxiters)
219
-
end
220
-
221
-
f =instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
222
-
223
-
!(opt isa Optim.ZerothOrderOptimizer) && f.grad ===nothing&&error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
0 commit comments