Skip to content

Commit eec597c

Browse files
authored
Merge pull request #12 from SciML/smc/rename
Rename `create_ude_component` to `NeuralNetworkBlock`
2 parents 2cd876e + 59c178e commit eec597c

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/UDEComponents.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,23 @@ using Lux: Lux
88
using Random: Xoshiro
99
using ComponentArrays: ComponentArray
1010

11-
export create_ude_component, multi_layer_feed_forward
11+
export NeuralNetworkBlock, multi_layer_feed_forward
1212

1313
include("utils.jl")
1414

1515
"""
16-
create_ude_component(n_input = 1, n_output = 1;
16+
NeuralNetworkBlock(n_input = 1, n_output = 1;
1717
chain = multi_layer_feed_forward(n_input, n_output),
1818
rng = Xoshiro(0))
1919
2020
Create an `ODESystem` with a neural network inside.
2121
"""
22-
function create_ude_component(n_input = 1,
22+
function NeuralNetworkBlock(n_input = 1,
2323
n_output = 1;
2424
chain = multi_layer_feed_forward(n_input, n_output),
25-
rng = Xoshiro(0))
26-
lux_p, st = Lux.setup(rng, chain)
27-
ca = ComponentArray(lux_p)
25+
rng = Xoshiro(0), eltype=Float64)
26+
lux_p = Lux.initialparameters(rng, chain)
27+
ca = ComponentArray{eltype}(lux_p)
2828

2929
@parameters p[1:length(ca)] = Vector(ca)
3030
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]

test/lotka_volterra.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using StableRNGs
1414

1515
function lotka_ude()
1616
@variables t x(t)=3.1 y(t)=1.5
17-
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8
17+
@parameters α=1.3 [tunable=false] δ=1.8 [tunable=false]
1818
Dt = ModelingToolkit.D_nounits
1919
@named nn_in = RealInput(nin = 2)
2020
@named nn_out = RealOutput(nout = 2)
@@ -44,7 +44,7 @@ end
4444
model = lotka_ude()
4545

4646
chain = multi_layer_feed_forward(2, 2)
47-
nn = create_ude_component(2, 2; chain, rng = StableRNG(42))
47+
nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
4848

4949
eqs = [connect(model.nn_in, nn.output)
5050
connect(model.nn_out, nn.input)]
@@ -67,7 +67,7 @@ get_refs = getu(model_true, [model_true.x, model_true.y])
6767

6868
function loss(x, (prob, sol_ref, get_vars, get_refs))
6969
new_p = SciMLStructures.replace(Tunable(), prob.p, x)
70-
new_prob = remake(prob, p = new_p)
70+
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
7171
ts = sol_ref.t
7272
new_sol = solve(new_prob, Rodas4(), saveat = ts)
7373

@@ -115,7 +115,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
115115

116116
res_p = SciMLStructures.replace(Tunable(), prob.p, res)
117117
res_prob = remake(prob, p = res_p)
118-
res_sol = solve(res_prob, Rodas4())
118+
res_sol = solve(res_prob, Rodas4(), saveat=sol_ref.t)
119119

120120
# using Plots
121121
# plot(sol_ref, idxs = [model_true.x, model_true.y])

0 commit comments

Comments
 (0)