Skip to content

Commit bec8951

Browse files
Merge pull request #21 from SciML/sb/initparams
refactor: add `init_params` kwarg for passing parameters for the NeuralNetworkBlock
2 parents a34e555 + db92f53 commit bec8951

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,19 @@ include("utils.jl")
1515
"""
1616
NeuralNetworkBlock(n_input = 1, n_output = 1;
1717
chain = multi_layer_feed_forward(n_input, n_output),
18-
rng = Xoshiro(0), eltype = Float64)
18+
rng = Xoshiro(0),
19+
init_params = Lux.initialparameters(rng, chain),
20+
eltype = Float64)
1921
2022
Create an `ODESystem` with a neural network inside.
2123
"""
2224
function NeuralNetworkBlock(n_input = 1,
2325
n_output = 1;
2426
chain = multi_layer_feed_forward(n_input, n_output),
25-
rng = Xoshiro(0), eltype = Float64)
26-
lux_p = Lux.initialparameters(rng, chain)
27-
ca = ComponentArray{eltype}(lux_p)
27+
rng = Xoshiro(0),
28+
init_params = Lux.initialparameters(rng, chain),
29+
eltype = Float64)
30+
ca = ComponentArray{eltype}(init_params)
2831

2932
@parameters p[1:length(ca)] = Vector(ca)
3033
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]

0 commit comments

Comments
 (0)