File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -8,23 +8,23 @@ using Lux: Lux
8
8
using Random: Xoshiro
9
9
using ComponentArrays: ComponentArray
10
10
11
- export create_ude_component , multi_layer_feed_forward
11
+ export NeuralNetworkBlock , multi_layer_feed_forward
12
12
13
13
include (" utils.jl" )
14
14
15
15
"""
16
- create_ude_component (n_input = 1, n_output = 1;
16
+ NeuralNetworkBlock (n_input = 1, n_output = 1;
17
17
chain = multi_layer_feed_forward(n_input, n_output),
18
18
rng = Xoshiro(0))
19
19
20
20
Create an `ODESystem` with a neural network inside.
21
21
"""
22
- function create_ude_component (n_input = 1 ,
22
+ function NeuralNetworkBlock (n_input = 1 ,
23
23
n_output = 1 ;
24
24
chain = multi_layer_feed_forward (n_input, n_output),
25
- rng = Xoshiro (0 ))
26
- lux_p, st = Lux. setup (rng, chain)
27
- ca = ComponentArray (lux_p)
25
+ rng = Xoshiro (0 ), eltype = Float64 )
26
+ lux_p = Lux. initialparameters (rng, chain)
27
+ ca = ComponentArray {eltype} (lux_p)
28
28
29
29
@parameters p[1 : length (ca)] = Vector (ca)
30
30
@parameters T:: typeof (typeof (p))= typeof (p) [tunable = false ]
You can’t perform that action at this time.
0 commit comments