Skip to content

Commit d8ece0d

Browse files
Merge pull request #8 from SciML/smc/fix
Fix NaN in gradients
2 parents 858f3d1 + 2a018dc commit d8ece0d

File tree

4 files changed

+14
-8
lines changed

4 files changed

+14
-8
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
99
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1010
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1111
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
12-
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1312
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1413
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1514

@@ -22,13 +21,13 @@ Lux = "0.5.32"
2221
LuxCore = "0.1.14"
2322
ModelingToolkit = "9.9.0"
2423
ModelingToolkitStandardLibrary = "2.6"
25-
NNlib = "0.9"
2624
Optimization = "3.22"
2725
OptimizationOptimisers = "0.2"
2826
OrdinaryDiffEq = "6.74"
2927
Random = "1"
3028
SafeTestsets = "0.1"
3129
SciMLStructures = "1.1.0"
30+
StableRNGs = "1"
3231
SymbolicIndexingInterface = "0.3.15"
3332
Symbolics = "5.27"
3433
Test = "1"
@@ -43,8 +42,9 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
4342
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
4443
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4544
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
45+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4646
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4747
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4848

4949
[targets]
50-
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "SymbolicIndexingInterface"]
50+
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "StableRNGs", "SymbolicIndexingInterface"]

src/UDEComponents.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using Symbolics: Symbolics, @register_array_symbolic, @wrapped
66
using LuxCore: stateless_apply
77
using Lux: Lux
88
using Random: Xoshiro
9-
using NNlib: softplus
109
using ComponentArrays: ComponentArray
1110

1211
export create_ude_component, multi_layer_feed_forward

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function multi_layer_feed_forward(input_length, output_length; width::Int = 5,
2-
depth::Int = 1, activation = softplus, disable_optimizations = false)
2+
depth::Int = 1, activation = tanh, disable_optimizations = false)
33
Lux.Chain(Lux.Dense(input_length, width, activation),
44
[Lux.Dense(width, width, activation) for _ in 1:(depth)]...,
55
Lux.Dense(width, output_length); disable_optimizations)

test/lotka_volterra.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using OptimizationOptimisers: Adam
1010
using SciMLStructures
1111
using SciMLStructures: Tunable
1212
using ForwardDiff
13+
using StableRNGs
1314

1415
function lotka_ude()
1516
@variables t x(t)=3.1 y(t)=1.5
@@ -41,7 +42,9 @@ function lotka_true()
4142
end
4243

4344
model = lotka_ude()
44-
nn = create_ude_component(2, 2)
45+
46+
chain = multi_layer_feed_forward(2, 2)
47+
nn = create_ude_component(2, 2; chain, rng = StableRNG(42))
4548

4649
eqs = [connect(model.nn_in, nn.output)
4750
connect(model.nn_out, nn.input)]
@@ -71,7 +74,7 @@ function loss(x, (prob, sol_ref, get_vars, get_refs))
7174
loss = zero(eltype(x))
7275

7376
for i in eachindex(new_sol.u)
74-
loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i))))
77+
loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
7578
end
7679

7780
if SciMLBase.successful_retcode(new_sol)
@@ -106,12 +109,16 @@ op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
106109
# false
107110
# end
108111

109-
res = solve(op, Adam(), maxiters = 2000)#, callback = plot_cb)
112+
res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
110113

111114
@test res.objective < 1
112115

113116
res_p = SciMLStructures.replace(Tunable(), prob.p, res)
114117
res_prob = remake(prob, p = res_p)
115118
res_sol = solve(res_prob, Rodas4())
116119

120+
# using Plots
121+
# plot(sol_ref, idxs = [model_true.x, model_true.y])
122+
# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
123+
117124
@test SciMLBase.successful_retcode(res_sol)

0 commit comments

Comments
 (0)