Skip to content

Commit c0ee339

Browse files
committed
refactor: rename create_ude_component to NeuralNetworkBlock
1 parent 2cd876e commit c0ee339

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/UDEComponents.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,23 @@ using Lux: Lux
88
using Random: Xoshiro
99
using ComponentArrays: ComponentArray
1010

11-
export create_ude_component, multi_layer_feed_forward
11+
export NeuralNetworkBlock, multi_layer_feed_forward
1212

1313
include("utils.jl")
1414

1515
"""
16-
create_ude_component(n_input = 1, n_output = 1;
16+
NeuralNetworkBlock(n_input = 1, n_output = 1;
1717
chain = multi_layer_feed_forward(n_input, n_output),
1818
rng = Xoshiro(0))
1919
2020
Create an `ODESystem` with a neural network inside.
2121
"""
22-
function create_ude_component(n_input = 1,
22+
function NeuralNetworkBlock(n_input = 1,
2323
n_output = 1;
2424
chain = multi_layer_feed_forward(n_input, n_output),
25-
rng = Xoshiro(0))
26-
lux_p, st = Lux.setup(rng, chain)
27-
ca = ComponentArray(lux_p)
25+
rng = Xoshiro(0), eltype=Float64)
26+
lux_p = Lux.initialparameters(rng, chain)
27+
ca = ComponentArray{eltype}(lux_p)
2828

2929
@parameters p[1:length(ca)] = Vector(ca)
3030
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]

0 commit comments

Comments
 (0)