Skip to content

Commit 145db5e

Browse files
committed
feat: add initial implementation
1 parent 0bd6e2e commit 145db5e

File tree

10 files changed

+260
-13
lines changed

10 files changed

+260
-13
lines changed

.JuliaFormatter.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options
1+
style = "sciml"
2+
format_markdown = true
3+
format_docstrings = true
4+
annotate_untyped_fields_with_any = false

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
/Manifest.toml
55
/docs/Manifest.toml
66
/docs/build/
7+
.vscode

Project.toml

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,48 @@ uuid = "f162e290-f571-43a6-83d9-22ecc16da15f"
33
authors = ["Sebastian Micluța-Câmpeanu <[email protected]> and contributors"]
44
version = "1.0.0-DEV"
55

6+
[deps]
7+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
8+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
9+
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
10+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
11+
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
12+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
13+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
15+
616
[compat]
17+
Aqua = "0.8"
18+
ComponentArrays = "0.15"
19+
ForwardDiff = "0.10.36"
20+
JET = "0.8"
21+
Lux = "0.5.32"
22+
LuxCore = "0.1.14"
23+
ModelingToolkit = "9.9.0"
24+
ModelingToolkitStandardLibrary = "2.6"
25+
NNlib = "0.9"
26+
Optimization = "3.22"
27+
OptimizationOptimisers = "0.2"
28+
OrdinaryDiffEq = "6.74"
29+
Random = "1"
30+
SafeTestsets = "0.1"
31+
SciMLStructures = "1.1.0"
32+
SymbolicIndexingInterface = "0.3.15"
33+
Symbolics = "5.27"
34+
Test = "1"
735
julia = "1.10"
836

937
[extras]
1038
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
39+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1140
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
41+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
42+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
43+
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
44+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
45+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
46+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1247
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1348

1449
[targets]
15-
test = ["Aqua", "JET", "Test"]
50+
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "SymbolicIndexingInterface"]

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@
66
[![Coverage](https://codecov.io/gh/SebastianM-C/UDEComponents.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/SebastianM-C/UDEComponents.jl)
77
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
88
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
9+
10+
## Build UDEs with ModelingToolkit

src/UDEComponents.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,47 @@
11
module UDEComponents
22

3-
# Write your package code here.
3+
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
4+
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
5+
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
6+
using LuxCore: stateless_apply
7+
using Lux: Lux
8+
using Random: Xoshiro
9+
using NNlib: softplus
10+
using ComponentArrays: ComponentArray
11+
12+
export create_ude_component, multi_layer_feed_forward
13+
14+
include("utils.jl")
15+
include("hacks.jl") # this should be removed / upstreamed
16+
17+
"""
18+
19+
create_ude_component(n_input = 1, n_output = 1;
20+
chain = multi_layer_feed_forward(n_input, n_output),
21+
rng = Xoshiro(0))
22+
23+
Create an `ODESystem` with a neural network inside.
24+
"""
25+
function create_ude_component(n_input = 1,
26+
n_output = 1;
27+
chain = multi_layer_feed_forward(n_input, n_output),
28+
rng = Xoshiro(0))
29+
lux_p, st = Lux.setup(rng, chain)
30+
ca = ComponentArray(lux_p)
31+
32+
@parameters p[1:length(ca)] = Vector(ca)
33+
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
34+
35+
@named input = RealInput(nin = n_input)
36+
@named output = RealOutput(nout = n_output)
37+
38+
out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p))
39+
40+
eqs = [output.u ~ out]
41+
42+
@named ude_comp = ODESystem(
43+
eqs, t_nounits, [], [p, T], systems = [input, output])
44+
return ude_comp
45+
end
446

547
end

src/hacks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
lazyconvert(x, y) = convert(x, y)
2+
lazyconvert(x, y::Symbolics.Arr) = Symbolics.array_term(convert, x, y)
3+
Symbolics.propagate_ndims(::typeof(convert), x, y) = ndims(y)
4+
Symbolics.propagate_shape(::typeof(convert), x, y) = Symbolics.shape(y)

src/utils.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
function multi_layer_feed_forward(input_length, output_length; width::Int = 5,
2+
depth::Int = 1, activation = softplus)
3+
Lux.Chain(Lux.Dense(input_length, width, activation),
4+
[Lux.Dense(width, width, activation) for _ in 1:(depth)]...,
5+
Lux.Dense(width, output_length); disable_optimizations = true)
6+
end
7+
8+
# Symbolics.@register_array_symbolic print_input(x) begin
9+
# size = size(x)
10+
# eltype = eltype(x)
11+
# end
12+
13+
# function print_input(x)
14+
# @info x
15+
# x
16+
# end
17+
18+
# function debug_component(n_input, n_output)
19+
# @named input = RealInput(nin = n_input)
20+
# @named output = RealOutput(nout = n_output)
21+
22+
# eqs = [output.u ~ print_input(input.u)]
23+
24+
# @named dbg_comp = ODESystem(eqs, t_nounits, [], [], systems = [input, output])
25+
# end

test/lotka_volterra.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
using Test
2+
using JET
3+
using UDEComponents
4+
using ModelingToolkit
5+
using ModelingToolkitStandardLibrary.Blocks
6+
using OrdinaryDiffEq
7+
using SymbolicIndexingInterface
8+
using Optimization
9+
using OptimizationOptimisers: Adam
10+
using SciMLStructures
11+
using SciMLStructures: Tunable
12+
using ForwardDiff
13+
14+
function lotka_ude()
15+
@variables t x(t)=3.1 y(t)=1.5
16+
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8
17+
Dt = ModelingToolkit.D_nounits
18+
@named nn_in = RealInput(nin = 2)
19+
@named nn_out = RealOutput(nout = 2)
20+
21+
eqs = [
22+
Dt(x) ~ α * x + nn_in.u[1],
23+
Dt(y) ~ -δ * y + nn_in.u[2],
24+
nn_out.u[1] ~ x,
25+
nn_out.u[2] ~ y
26+
]
27+
return ODESystem(
28+
eqs, ModelingToolkit.t_nounits, name = :lotka, systems = [nn_in, nn_out])
29+
end
30+
31+
function lotka_true()
32+
@variables t x(t)=3.1 y(t)=1.5
33+
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8
34+
Dt = ModelingToolkit.D_nounits
35+
36+
eqs = [
37+
Dt(x) ~ α * x - β * x * y,
38+
Dt(y) ~ -δ * y + δ * x * y
39+
]
40+
return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka_true)
41+
end
42+
43+
model = lotka_ude()
44+
nn = create_ude_component(2, 2)
45+
46+
eqs = [
47+
connect(model.nn_in, nn.output)
48+
connect(model.nn_out, nn.input)
49+
]
50+
51+
ude_sys = complete(ODESystem(
52+
eqs, ModelingToolkit.t_nounits, systems = [model, nn], name = :ude_sys))
53+
54+
sys = structural_simplify(ude_sys)
55+
56+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
57+
58+
model_true = structural_simplify(lotka_true())
59+
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
60+
sol_ref = solve(prob_true, Rodas4())
61+
62+
x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))
63+
64+
get_vars = getu(sys, [sys.lotka.x, sys.lotka.y])
65+
get_refs = getu(model_true, [model_true.x, model_true.y])
66+
67+
function loss(x, (prob, sol_ref, get_vars, get_refs))
68+
new_p = SciMLStructures.replace(Tunable(), prob.p, x)
69+
new_prob = remake(prob, p = new_p)
70+
ts = sol_ref.t
71+
new_sol = solve(new_prob, Rodas4(), saveat = ts)
72+
73+
loss = zero(eltype(x))
74+
75+
for i in eachindex(new_sol.u)
76+
loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i))))
77+
end
78+
79+
if SciMLBase.successful_retcode(new_sol)
80+
loss
81+
else
82+
Inf
83+
end
84+
end
85+
86+
87+
of = OptimizationFunction{true}(loss, AutoForwardDiff())
88+
89+
ps = (prob, sol_ref, get_vars, get_refs);
90+
91+
@test_call target_modules=(UDEComponents,) loss(x0, ps)
92+
@test_opt target_modules=(UDEComponents,) loss(x0, ps)
93+
94+
@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))
95+
96+
op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
97+
98+
99+
# using Plots
100+
101+
# oh = []
102+
103+
plot_cb = (opt_state, loss) -> begin
104+
@info "step $(opt_state.iter), loss: $loss"
105+
# push!(oh, opt_state)
106+
# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
107+
# new_prob = remake(prob, p = new_p)
108+
# sol = solve(new_prob, Rodas4())
109+
# display(plot(sol))
110+
false
111+
end
112+
113+
res = solve(op, Adam(), maxiters = 2000)#, callback = plot_cb)
114+
115+
@test res.objective < 1
116+
117+
res_p = SciMLStructures.replace(Tunable(), prob.p, res)
118+
res_prob = remake(prob, p = res_p)
119+
res_sol = solve(res_prob, Rodas4())
120+
121+
@test SciMLBase.successful_retcode(res_sol)

test/qa.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Test
2+
using UDEComponents
3+
using Aqua
4+
using JET
5+
6+
@testset verbose = true "Code quality (Aqua.jl)" begin
7+
Aqua.find_persistent_tasks_deps(UDEComponents)
8+
Aqua.test_ambiguities(UDEComponents, recursive = false)
9+
Aqua.test_deps_compat(UDEComponents)
10+
# TODO: fix type piracy in propagate_ndims and propagate_shape
11+
Aqua.test_piracies(UDEComponents, broken=true)
12+
Aqua.test_project_extras(UDEComponents)
13+
Aqua.test_stale_deps(UDEComponents, ignore = Symbol[])
14+
Aqua.test_unbound_args(UDEComponents)
15+
Aqua.test_undefined_exports(UDEComponents)
16+
end
17+
18+
@testset "Code linting (JET.jl)" begin
19+
JET.test_package(UDEComponents; target_defined_modules = true)
20+
end

test/runtests.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
using UDEComponents
22
using Test
3-
using Aqua
4-
using JET
3+
using SafeTestsets
54

6-
@testset "UDEComponents.jl" begin
7-
@testset "Code quality (Aqua.jl)" begin
8-
Aqua.test_all(UDEComponents)
9-
end
10-
@testset "Code linting (JET.jl)" begin
11-
JET.test_package(UDEComponents; target_defined_modules = true)
12-
end
13-
# Write your tests here.
5+
@testset verbose=true "UDEComponents.jl" begin
6+
@safetestset "QA" include("qa.jl")
7+
@safetestset "Basic" include("lotka_volterra.jl")
148
end

0 commit comments

Comments
 (0)