Skip to content

Commit 970f8dc

Browse files
author
Sathvik Bhagavan
committed
refactor: use RealInputArray and RealOutputArray
1 parent 02041f1 commit 970f8dc

File tree

1 file changed

+5
-23
lines changed

1 file changed

+5
-23
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ModelingToolkitNeuralNets
22

33
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits, @connector, @variables,
44
Equation
5-
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
5+
using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
66
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
77
using LuxCore: stateless_apply
88
using Lux: Lux
@@ -13,24 +13,6 @@ export NeuralNetworkBlock, multi_layer_feed_forward
1313

1414
include("utils.jl")
1515

16-
@connector function RealInput2(; name, nin = 1, u_start = zeros(nin))
17-
@variables u(t_nounits)[1:nin]=u_start [
18-
input = true,
19-
description = "Inner variable in RealInput $name"
20-
]
21-
u = collect(u)
22-
ODESystem(Equation[], t_nounits, [u...], []; name = name)
23-
end
24-
25-
@connector function RealOutput2(; name, nout = 1, u_start = zeros(nout))
26-
@variables u(t_nounits)[1:nout]=u_start [
27-
output = true,
28-
description = "Inner variable in RealOutput $name"
29-
]
30-
u = collect(u)
31-
ODESystem(Equation[], t_nounits, [u...], []; name = name)
32-
end
33-
3416
"""
3517
NeuralNetworkBlock(n_input = 1, n_output = 1;
3618
chain = multi_layer_feed_forward(n_input, n_output),
@@ -49,17 +31,17 @@ function NeuralNetworkBlock(n_input = 1,
4931
ca = ComponentArray{eltype}(init_params)
5032

5133
@parameters p[1:length(ca)] = Vector(ca)
52-
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
34+
# @parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
5335

54-
@named input = RealInput2(nin = n_input)
55-
@named output = RealOutput2(nout = n_output)
36+
@named input = RealInputArray(nin = n_input)
37+
@named output = RealOutputArray(nout = n_output)
5638

5739
out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p))
5840

5941
eqs = [output.u ~ out]
6042

6143
@named ude_comp = ODESystem(
62-
eqs, t_nounits, [], [p, T], systems = [input, output])
44+
eqs, t_nounits, [], [p], systems = [input, output])
6345
return ude_comp
6446
end
6547

0 commit comments

Comments
 (0)