
Commit 5385843

Merge pull request #2600 from AayushSabharwal/as/remake-tutorial
docs: add tutorial for optimizing ODE solve and remake
2 parents 13ab6d9 + 6b498dc commit 5385843

3 files changed: +165 −1 lines changed


docs/Project.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
 StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
 SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
@@ -33,6 +34,7 @@ Optimization = "3.9"
 OptimizationOptimJL = "0.1"
 OrdinaryDiffEq = "6.31"
 Plots = "1.36"
+SciMLStructures = "1.1"
 StochasticDiffEq = "6"
 SymbolicIndexingInterface = "0.3.1"
 SymbolicUtils = "1"
```

docs/pages.jl

Lines changed: 2 additions & 1 deletion
```diff
@@ -17,7 +17,8 @@ pages = [
     "Basic Examples" => Any["examples/higher_order.md",
         "examples/spring_mass.md",
         "examples/modelingtoolkitize_index_reduction.md",
-        "examples/parsing.md"],
+        "examples/parsing.md",
+        "examples/remake.md"],
     "Advanced Examples" => Any["examples/tearing_parallelism.md",
         "examples/sparse_jacobians.md",
         "examples/perturbation.md"]],
```

docs/src/examples/remake.md

Lines changed: 161 additions & 0 deletions
# Optimizing through an ODE solve and re-creating MTK Problems

Solving an ODE as part of an `OptimizationProblem`'s loss function is a common scenario.
In this example, we will go through an efficient way to model such scenarios using
ModelingToolkit.jl.

First, we build the ODE to be solved. For this example, we will use a Lotka-Volterra model:

```@example Remake
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkbuild odesys = ODESystem(eqs, t)
```

To create the "data" for optimization, we will solve the system with a known set of
parameters.

```@example Remake
using OrdinaryDiffEq

odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
timesteps = 0.0:0.1:10.0
sol = solve(odeprob, Tsit5(); saveat = timesteps)
data = Array(sol)
# add some random noise
data = data + 0.01 * randn(size(data))
```

Now we will create the loss function for the optimization solve. This requires an
`ODEProblem` whose parameter values are the ones passed to the loss function. Constructing
a new `ODEProblem` on every call is expensive and requires differentiating through the
code-generation process, which is bug-prone and unnecessary. Instead, we will leverage the
`remake` function, which creates a copy of an existing problem with updated state/parameter
values. Note that the types of the values passed to the loss function (e.g.
`ForwardDiff.Dual`) may not agree with the types stored in the existing `ODEProblem`, so we
cannot use `setp` to modify the problem in-place. Here, we will use the `replace` function
from SciMLStructures.jl, since it allows updating the entire `Tunable` portion of the
parameter object, which contains the parameters to optimize.

```@example Remake
using SymbolicIndexingInterface: parameter_values, state_values
using SciMLStructures: Tunable, replace, replace!

function loss(x, p)
    odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
    ps = parameter_values(odeprob) # obtain the parameter object from the problem
    ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
    T = eltype(x)
    # we also have to convert the `u0` vector
    u0 = T.(state_values(odeprob))
    # remake the problem, passing in our new parameter object
    newprob = remake(odeprob; u0 = u0, p = ps)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end
```

Note how the problem, timesteps and true data are stored as model parameters. This helps
avoid referencing global variables in the function, which would slow it down significantly.

We could have done the same thing by passing `remake` a map of parameter values. For
example, if we enforce that the order of ODE parameters in `x` is `[α, β, γ, δ]`, we could
have written:

```julia
remake(odeprob; p = [α => x[1], β => x[2], γ => x[3], δ => x[4]])
```

However, passing a symbolic map to `remake` is significantly slower than passing it a
parameter object directly. Thus, we use `replace` to speed up the process. In general,
`remake` is the most flexible method, but the flexibility comes at a cost of performance.

We can perform the optimization as below:

```@example Remake
using Optimization
using OptimizationOptimJL

# manually create an OptimizationFunction to ensure usage of `ForwardDiff`, which will
# require changing the types of parameters from `Float64` to `ForwardDiff.Dual`
optfn = OptimizationFunction(loss, Optimization.AutoForwardDiff())
# the parameter object is a tuple, to store differently typed objects together
optprob = OptimizationProblem(
    optfn, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
sol = solve(optprob, BFGS())
```

To identify which values correspond to which parameters, we can `replace!` them into the
`ODEProblem`:

```@example Remake
replace!(Tunable(), parameter_values(odeprob), sol.u)
odeprob.ps[[α, β, γ, δ]]
```

`replace!` operates in-place, so the values being replaced must be of the same type as those
stored in the parameter object, or convertible to that type. For demonstration purposes, we
can construct a loss function that uses `replace!`, and calculate gradients using
`AutoFiniteDiff` rather than `AutoForwardDiff`.

```@example Remake
function loss2(x, p)
    odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
    newprob = remake(odeprob) # copy the problem with `remake`
    # update the parameter values in-place
    replace!(Tunable(), parameter_values(newprob), x)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end

# use finite-differencing to calculate derivatives
optfn2 = OptimizationFunction(loss2, Optimization.AutoFiniteDiff())
optprob2 = OptimizationProblem(
    optfn2, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
sol = solve(optprob2, BFGS())
```

# Re-creating the problem

There are multiple ways to re-create a problem with new state/parameter values. We will go
over the various methods, listing their use cases.

## Pure `remake`

This method is the most generic. It can handle symbolic maps, initialization of
parameters/states that depend on each other, and partial updates. However, this comes at
the cost of performance, and `remake` is not always inferrable.

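As a sketch of this (reusing `odeprob` and the symbolic variables defined earlier), a symbolic map lets `remake` update only a subset of values; the specific numbers below are hypothetical:

```julia
# hypothetical partial update: only α and the initial value of x change;
# all other parameters/states keep their existing values
newprob = remake(odeprob; u0 = [x => 2.0], p = [α => 1.2])
```
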
## `remake` and `setp`/`setu`

Calling `remake(prob)` creates a copy of the existing problem. This new problem has the
exact same types as the original one, and the `remake` call is fully inferred.
State/parameter values can be modified after the copy by using `setp` and/or `setu`. This
is most appropriate when the types of state/parameter values do not need to be changed,
only their values.

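A minimal sketch of this workflow, reusing the problem from above (`setp`/`setu` are provided by SymbolicIndexingInterface.jl; the new values here are hypothetical):

```julia
using SymbolicIndexingInterface: setp, setu

newprob = remake(odeprob)       # fully inferred copy, same types as the original
set_αβ! = setp(newprob, [α, β]) # build a setter function once; it can be reused
set_αβ!(newprob, [1.2, 0.8])    # update parameter values in-place
set_x! = setu(newprob, x)       # setter for the unknown x
set_x!(newprob, 2.0)
```
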
## `replace` and `remake`

`replace` returns a copy of a parameter object, with the appropriate portion replaced by new
values. This is useful for changing the type of an entire portion, such as during the
optimization process described above. `remake` is used in this case to create a copy of the
problem with updated state/unknown values.

## `remake` and `replace!`

`replace!` is similar to `replace`, except that it operates in-place. This means that the
new parameter values must be of the same types as those already stored. This is useful when
bulk parameter replacement is required without needing to change types, for example in
optimization methods where the gradient is not computed using dual numbers (as demonstrated
above).
