File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -15,16 +15,19 @@ include("utils.jl")
15
15
"""
16
16
NeuralNetworkBlock(n_input = 1, n_output = 1;
17
17
chain = multi_layer_feed_forward(n_input, n_output),
18
- rng = Xoshiro(0), eltype = Float64)
18
+ rng = Xoshiro(0),
19
+ init_params = Lux.initialparameters(rng, chain),
20
+ eltype = Float64)
19
21
20
22
Create an `ODESystem` with a neural network inside.
21
23
"""
22
24
function NeuralNetworkBlock (n_input = 1 ,
23
25
n_output = 1 ;
24
26
chain = multi_layer_feed_forward (n_input, n_output),
25
- rng = Xoshiro (0 ), eltype = Float64)
26
- lux_p = Lux. initialparameters (rng, chain)
27
- ca = ComponentArray {eltype} (lux_p)
27
+ rng = Xoshiro (0 ),
28
+ init_params = Lux. initialparameters (rng, chain),
29
+ eltype = Float64)
30
+ ca = ComponentArray {eltype} (init_params)
28
31
29
32
@parameters p[1 : length (ca)] = Vector (ca)
30
33
@parameters T:: typeof (typeof (p))= typeof (p) [tunable = false ]
You can’t perform that action at this time.
0 commit comments