@@ -2,7 +2,7 @@ module ModelingToolkitNeuralNets
2
2
3
3
using ModelingToolkit: @parameters , @named , ODESystem, t_nounits, @connector , @variables ,
4
4
Equation
5
- using ModelingToolkitStandardLibrary. Blocks: RealInput, RealOutput
5
+ using ModelingToolkitStandardLibrary. Blocks: RealInputArray, RealOutputArray
6
6
using Symbolics: Symbolics, @register_array_symbolic , @wrapped
7
7
using LuxCore: stateless_apply
8
8
using Lux: Lux
@@ -13,24 +13,6 @@ export NeuralNetworkBlock, multi_layer_feed_forward
13
13
14
14
include (" utils.jl" )
15
15
16
- @connector function RealInput2 (; name, nin = 1 , u_start = zeros (nin))
17
- @variables u (t_nounits)[1 : nin]= u_start [
18
- input = true ,
19
- description = " Inner variable in RealInput $name "
20
- ]
21
- u = collect (u)
22
- ODESystem (Equation[], t_nounits, [u... ], []; name = name)
23
- end
24
-
25
- @connector function RealOutput2 (; name, nout = 1 , u_start = zeros (nout))
26
- @variables u (t_nounits)[1 : nout]= u_start [
27
- output = true ,
28
- description = " Inner variable in RealOutput $name "
29
- ]
30
- u = collect (u)
31
- ODESystem (Equation[], t_nounits, [u... ], []; name = name)
32
- end
33
-
34
16
"""
35
17
NeuralNetworkBlock(n_input = 1, n_output = 1;
36
18
chain = multi_layer_feed_forward(n_input, n_output),
@@ -49,17 +31,17 @@ function NeuralNetworkBlock(n_input = 1,
49
31
ca = ComponentArray {eltype} (init_params)
50
32
51
33
@parameters p[1 : length (ca)] = Vector (ca)
52
- @parameters T:: typeof (typeof (p))= typeof (p) [tunable = false ]
34
+ # @parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
53
35
54
- @named input = RealInput2 (nin = n_input)
55
- @named output = RealOutput2 (nout = n_output)
36
+ @named input = RealInputArray (nin = n_input)
37
+ @named output = RealOutputArray (nout = n_output)
56
38
57
39
out = stateless_apply (chain, input. u, lazyconvert (typeof (ca), p))
58
40
59
41
eqs = [output. u ~ out]
60
42
61
43
@named ude_comp = ODESystem (
62
- eqs, t_nounits, [], [p, T ], systems = [input, output])
44
+ eqs, t_nounits, [], [p], systems = [input, output])
63
45
return ude_comp
64
46
end
65
47
0 commit comments