Skip to content

Commit 59c178e

Browse files
committed
test: only train the neural network parameters and promote the u0 type
1 parent c0ee339 commit 59c178e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

test/lotka_volterra.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using StableRNGs
1414

1515
function lotka_ude()
1616
@variables t x(t)=3.1 y(t)=1.5
17-
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8
17+
@parameters α=1.3 [tunable=false] δ=1.8 [tunable=false]
1818
Dt = ModelingToolkit.D_nounits
1919
@named nn_in = RealInput(nin = 2)
2020
@named nn_out = RealOutput(nout = 2)
@@ -44,7 +44,7 @@ end
4444
model = lotka_ude()
4545

4646
chain = multi_layer_feed_forward(2, 2)
47-
nn = create_ude_component(2, 2; chain, rng = StableRNG(42))
47+
nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
4848

4949
eqs = [connect(model.nn_in, nn.output)
5050
connect(model.nn_out, nn.input)]
@@ -67,7 +67,7 @@ get_refs = getu(model_true, [model_true.x, model_true.y])
6767

6868
function loss(x, (prob, sol_ref, get_vars, get_refs))
6969
new_p = SciMLStructures.replace(Tunable(), prob.p, x)
70-
new_prob = remake(prob, p = new_p)
70+
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
7171
ts = sol_ref.t
7272
new_sol = solve(new_prob, Rodas4(), saveat = ts)
7373

@@ -115,7 +115,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
115115

116116
res_p = SciMLStructures.replace(Tunable(), prob.p, res)
117117
res_prob = remake(prob, p = res_p)
118-
res_sol = solve(res_prob, Rodas4())
118+
res_sol = solve(res_prob, Rodas4(), saveat=sol_ref.t)
119119

120120
# using Plots
121121
# plot(sol_ref, idxs = [model_true.x, model_true.y])

0 commit comments

Comments
 (0)