Skip to content

Commit 9db4a5b

Browse files
author
Sathvik Bhagavan
committed
refactor: [temp] add RealInput/RealOutput which works for size equal to 1
1 parent c1330c6 commit 9db4a5b

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ModelingToolkitNeuralNets
22

3-
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
3+
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits, @connector, @variables,
4+
Equation
45
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
56
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
67
using LuxCore: stateless_apply
@@ -12,6 +13,24 @@ export NeuralNetworkBlock, multi_layer_feed_forward
1213

1314
include("utils.jl")
1415

16+
@connector function RealInput2(; name, nin = 1, u_start = zeros(nin))
17+
@variables u(t_nounits)[1:nin]=u_start [
18+
input = true,
19+
description = "Inner variable in RealInput $name"
20+
]
21+
u = collect(u)
22+
ODESystem(Equation[], t_nounits, [u...], []; name = name)
23+
end
24+
25+
@connector function RealOutput2(; name, nout = 1, u_start = zeros(nout))
26+
@variables u(t_nounits)[1:nout]=u_start [
27+
output = true,
28+
description = "Inner variable in RealOutput $name"
29+
]
30+
u = collect(u)
31+
ODESystem(Equation[], t_nounits, [u...], []; name = name)
32+
end
33+
1534
"""
1635
NeuralNetworkBlock(n_input = 1, n_output = 1;
1736
chain = multi_layer_feed_forward(n_input, n_output),
@@ -29,8 +48,8 @@ function NeuralNetworkBlock(n_input = 1,
2948
@parameters p[1:length(ca)] = Vector(ca)
3049
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
3150

32-
@named input = RealInput(nin = n_input)
33-
@named output = RealOutput(nout = n_output)
51+
@named input = RealInput2(nin = n_input)
52+
@named output = RealOutput2(nout = n_output)
3453

3554
out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p))
3655

0 commit comments

Comments
 (0)