@@ -14,7 +14,7 @@ using StableRNGs
14
14
15
15
function lotka_ude ()
16
16
@variables t x (t)= 3.1 y (t)= 1.5
17
- @parameters α= 1.3 β = 0.9 γ = 0.8 δ= 1.8
17
+ @parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
18
18
Dt = ModelingToolkit. D_nounits
19
19
@named nn_in = RealInput (nin = 2 )
20
20
@named nn_out = RealOutput (nout = 2 )
44
44
model = lotka_ude ()
45
45
46
46
chain = multi_layer_feed_forward (2 , 2 )
47
- nn = create_ude_component (2 , 2 ; chain, rng = StableRNG (42 ))
47
+ nn = NeuralNetworkBlock (2 , 2 ; chain, rng = StableRNG (42 ))
48
48
49
49
eqs = [connect (model. nn_in, nn. output)
50
50
connect (model. nn_out, nn. input)]
@@ -67,7 +67,7 @@ get_refs = getu(model_true, [model_true.x, model_true.y])
67
67
68
68
function loss (x, (prob, sol_ref, get_vars, get_refs))
69
69
new_p = SciMLStructures. replace (Tunable (), prob. p, x)
70
- new_prob = remake (prob, p = new_p)
70
+ new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob . u0) )
71
71
ts = sol_ref. t
72
72
new_sol = solve (new_prob, Rodas4 (), saveat = ts)
73
73
@@ -115,7 +115,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
115
115
116
116
res_p = SciMLStructures. replace (Tunable (), prob. p, res)
117
117
res_prob = remake (prob, p = res_p)
118
- res_sol = solve (res_prob, Rodas4 ())
118
+ res_sol = solve (res_prob, Rodas4 (), saveat = sol_ref . t )
119
119
120
120
# using Plots
121
121
# plot(sol_ref, idxs = [model_true.x, model_true.y])
0 commit comments