Skip to content

Commit 37890ff

Browse files
author
Sathvik Bhagavan
committed
fixup! refactor: use RealInputArray and RealOutputArray
1 parent 1f281f6 commit 37890ff

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module ModelingToolkitNeuralNets
22

3-
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits, @connector, @variables,
4-
Equation
3+
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
54
using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
65
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
76
using LuxCore: stateless_apply
@@ -31,17 +30,17 @@ function NeuralNetworkBlock(n_input = 1,
3130
ca = ComponentArray{eltype}(init_params)
3231

3332
@parameters p[1:length(ca)] = Vector(ca)
34-
# @parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
33+
@parameters T::typeof(typeof(ca))=typeof(ca) [tunable = false]
3534

3635
@named input = RealInputArray(nin = n_input)
3736
@named output = RealOutputArray(nout = n_output)
3837

39-
out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p))
38+
out = stateless_apply(chain, input.u, lazyconvert(T, p))
4039

4140
eqs = [output.u ~ out]
4241

4342
@named ude_comp = ODESystem(
44-
eqs, t_nounits, [], [p], systems = [input, output])
43+
eqs, t_nounits, [], [p, T], systems = [input, output])
4544
return ude_comp
4645
end
4746

0 commit comments

Comments
 (0)