Skip to content

Commit 9e8154d

Browse files
Merge pull request #509 from SciML/ChrisRackauckas-patch-2
Bump DiffEqFlux in tests
2 parents 6f95092 + 3d40fbd commit 9e8154d

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

test/Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
[deps]
22
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
34
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
45
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
56
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
67
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
78
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
9+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
810
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
911
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1012
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
@@ -19,11 +21,13 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1921
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2022

2123
[compat]
22-
DiffEqFlux = ">= 1.43.0"
24+
ComponentArrays = ">= 0.13.9"
25+
DiffEqFlux = ">= 2"
2326
FiniteDiff = ">= 2.8.1"
2427
Optimisers = ">= 0.2.5"
2528
ForwardDiff = ">= 0.10.19"
2629
IterTools = ">= 1.3.0"
30+
Lux = ">= 0.4.50"
2731
ModelingToolkit = ">= 8.11.0"
2832
Optim = ">= 1.4.1"
2933
OrdinaryDiffEq = ">= 5"

test/diffeqfluxtests.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using OrdinaryDiffEq, DiffEqFlux, Optimization, OptimizationOptimJL, OptimizationOptimisers,
2-
ForwardDiff
1+
using OrdinaryDiffEq, DiffEqFlux, Lux, Optimization, OptimizationOptimJL,
2+
OptimizationOptimisers, ForwardDiff, ComponentArrays, Random
3+
rng = Random.default_rng()
34

45
function lotka_volterra!(du, u, p, t)
56
x, y = u
@@ -68,17 +69,15 @@ end
6869
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
6970
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
7071

71-
dudt2 = FastChain((x, p) -> x .^ 3,
72-
FastDense(2, 50, tanh),
73-
FastDense(50, 2))
72+
dudt2 = Lux.Chain(x -> x .^ 3,
73+
Lux.Dense(2, 50, tanh),
74+
Lux.Dense(50, 2))
7475
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
75-
76-
dudt2 = Chain(x -> x .^ 3,
77-
Dense(2, 50, tanh),
78-
Dense(50, 2))
76+
pp, st = Lux.setup(rng, dudt2)
77+
pp = ComponentArray(pp)
7978

8079
function predict_neuralode(p)
81-
Array(prob_neuralode(u0, p))
80+
Array(prob_neuralode(u0, p, st)[1])
8281
end
8382

8483
function loss_neuralode(p)
@@ -98,7 +97,7 @@ end
9897

9998
optprob = OptimizationFunction((p, x) -> loss_neuralode(p), Optimization.AutoForwardDiff())
10099

101-
prob = Optimization.OptimizationProblem(optprob, prob_neuralode.p)
100+
prob = Optimization.OptimizationProblem(optprob, pp)
102101

103102
result_neuralode = Optimization.solve(prob,
104103
OptimizationOptimisers.ADAM(), callback = callback,

test/minibatch.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using DiffEqFlux, Optimization, OrdinaryDiffEq, OptimizationOptimisers, ModelingToolkit,
2-
SciMLSensitivity
2+
SciMLSensitivity, Lux, Random, ComponentArrays
3+
4+
rng = Random.default_rng()
35

46
function newtons_cooling(du, u, p, t)
57
temp = u[1]
@@ -13,7 +15,7 @@ function true_sol(du, u, p, t)
1315
end
1416

1517
function dudt_(u, p, t)
16-
ann(u, p) .* u
18+
ann(u, p, st)[1] .* u
1719
end
1820

1921
callback = function (p, l, pred; doplot = false) #callback function to observe training
@@ -35,8 +37,10 @@ t = range(tspan[1], tspan[2], length = datasize)
3537
true_prob = ODEProblem(true_sol, u0, tspan)
3638
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
3739

38-
ann = FastChain(FastDense(1, 8, tanh), FastDense(8, 1, tanh))
39-
pp = initial_params(ann)
40+
ann = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1, tanh))
41+
pp, st = Lux.setup(rng, ann)
42+
pp = ComponentArray(pp)
43+
4044
prob = ODEProblem{false}(dudt_, u0, tspan, pp)
4145

4246
function predict_adjoint(fullp, time_batch)

0 commit comments

Comments
 (0)