# Optimizing through an ODE solve and re-creating MTK Problems

Solving an ODE as part of an `OptimizationProblem`'s loss function is a common scenario.
In this example, we will go through an efficient way to model such scenarios using
ModelingToolkit.jl.

First, we build the ODE to be solved. For this example, we will use a Lotka-Volterra model:

```@example Remake
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkbuild odesys = ODESystem(eqs, t)
```

To create the "data" for optimization, we will solve the system with a known set of
parameters.

```@example Remake
using OrdinaryDiffEq

odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
timesteps = 0.0:0.1:10.0
sol = solve(odeprob, Tsit5(); saveat = timesteps)
data = Array(sol)
# add some random noise
data = data + 0.01 * randn(size(data))
```

Now we will create the loss function for the Optimization solve. This will require creating
an `ODEProblem` with the parameter values passed to the loss function. Creating a new
`ODEProblem` is expensive and requires differentiating through the code generation process.
This can be bug-prone and is unnecessary. Instead, we will leverage the `remake` function.
This allows creating a copy of an existing problem with updated state/parameter values. Note
that the types of the values passed to the loss function may not agree with the types stored
in the existing `ODEProblem`, so we cannot use `setp` to modify the problem in-place. Here,
we will use the `replace` function from SciMLStructures.jl, since it allows updating the
entire `Tunable` portion of the parameter object, which contains the parameters to optimize.

```@example Remake
using SymbolicIndexingInterface: parameter_values, state_values
using SciMLStructures: Tunable, replace, replace!

function loss(x, p)
    odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
    ps = parameter_values(odeprob) # obtain the parameter object from the problem
    ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
    T = eltype(x)
    # we also have to convert the `u0` vector
    u0 = T.(state_values(odeprob))
    # remake the problem, passing in our new parameter object
    newprob = remake(odeprob; u0 = u0, p = ps)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end
```
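
`replace` expects the new values in the canonical order of the flat `Tunable` vector. If
that order or length is ever in doubt, it can be inspected with `canonicalize` from
SciMLStructures.jl; a minimal sketch:

```julia
using SciMLStructures: canonicalize

# obtain the flat vector backing the `Tunable` portion, along with a function
# to rebuild the parameter object from such a vector
tunables, repack, alias = canonicalize(Tunable(), parameter_values(odeprob))
length(tunables) # the length `replace` expects for the replacement vector
```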

Note how the problem, timesteps and true data are stored as parameters of the optimization
problem. This helps avoid referencing global variables in the function, which would slow it
down significantly.

We could have done the same thing by passing `remake` a map of parameter values. For example,
let us enforce that the order of ODE parameters in `x` is `[α, β, γ, δ]`. Then, we could have
done:

```julia
remake(odeprob; p = [α => x[1], β => x[2], γ => x[3], δ => x[4]])
```

However, passing a symbolic map to `remake` is significantly slower than passing it a
parameter object directly. Thus, we use `replace` to speed up the process. In general,
`remake` is the most flexible method, but the flexibility comes at the cost of performance.
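
To get a feel for the difference, the two approaches can be timed against each other. A
rough sketch, assuming BenchmarkTools.jl is available (exact timings depend on the system
and package versions):

```julia
using BenchmarkTools
using SciMLStructures: canonicalize

# symbolic map: goes through symbolic processing on every call
@btime remake($odeprob; p = [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])

# flat replacement of the `Tunable` portion: avoids that overhead
vals = copy(canonicalize(Tunable(), parameter_values(odeprob))[1])
@btime remake($odeprob; p = replace(Tunable(), parameter_values($odeprob), $vals))
```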

We can perform the optimization as below:

```@example Remake
using Optimization
using OptimizationOptimJL

# manually create an OptimizationFunction to ensure usage of `ForwardDiff`, which will
# require changing the types of parameters from `Float64` to `ForwardDiff.Dual`
optfn = OptimizationFunction(loss, Optimization.AutoForwardDiff())
# parameter object is a tuple, to store differently typed objects together
optprob = OptimizationProblem(
    optfn, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
sol = solve(optprob, BFGS())
```

To identify which values correspond to which parameters, we can `replace!` them into the
`ODEProblem`:

```@example Remake
replace!(Tunable(), parameter_values(odeprob), sol.u)
odeprob.ps[[α, β, γ, δ]]
```

`replace!` operates in-place, so the values being replaced must be of the same type as those
stored in the parameter object, or convertible to that type. For demonstration purposes, we
can construct a loss function that uses `replace!`, and calculate gradients using
`AutoFiniteDiff` rather than `AutoForwardDiff`.

```@example Remake
function loss2(x, p)
    odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
    newprob = remake(odeprob) # copy the problem with `remake`
    # update the parameter values in-place
    replace!(Tunable(), parameter_values(newprob), x)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end

# use finite-differencing to calculate derivatives
optfn2 = OptimizationFunction(loss2, Optimization.AutoFiniteDiff())
optprob2 = OptimizationProblem(
    optfn2, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
sol = solve(optprob2, BFGS())
```

# Re-creating the problem

There are multiple ways to re-create a problem with new state/parameter values. We will go
over the various methods, listing their use cases.

## Pure `remake`

This method is the most generic. It can handle symbolic maps, initialization of
parameters/states that depend on each other, and partial updates. However, this comes at the
cost of performance. `remake` is also not always inferrable.
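
For instance, a sketch with arbitrary values: `remake` accepts partial symbolic maps, and
any states/parameters not mentioned keep their existing values:

```julia
# update only some values via symbolic maps; everything else is retained
# (the numbers here are arbitrary, for illustration only)
newprob = remake(odeprob; u0 = [x => 2.0], p = [α => 1.2, β => 0.8])
```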

## `remake` and `setp`/`setu`

Calling `remake(prob)` creates a copy of the existing problem. This new problem has the
exact same types as the original one, and the `remake` call is fully inferred.
State/parameter values can be modified after the copy by using `setp` and/or `setu`. This
is most appropriate when the types of state/parameter values do not need to be changed,
only their values.
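
A minimal sketch of this pattern (the parameter values are arbitrary):

```julia
using SymbolicIndexingInterface: setp

newprob = remake(odeprob)        # fully inferred copy with identical types
setter! = setp(newprob, [α, β])  # build a setter function once; it can be reused
setter!(newprob, [1.2, 0.8])     # update the chosen parameter values in-place
```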

## `replace` and `remake`

`replace` returns a copy of a parameter object, with the appropriate portion replaced by new
values. This is useful for changing the type of an entire portion, such as during the
optimization process described above. `remake` is used in this case to create a copy of the
problem with the updated parameter object and state/unknown values.
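
Condensed from the loss function above, the pattern is (here `x` stands for the vector of
new tunable values):

```julia
# `x` replaces the entire `Tunable` portion; its element type may differ from
# the stored one (e.g. `ForwardDiff.Dual` instead of `Float64`)
ps = replace(Tunable(), parameter_values(odeprob), x)
newprob = remake(odeprob; p = ps)
```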

## `remake` and `replace!`

`replace!` is similar to `replace`, except that it operates in-place. This means that the
new parameter values must be of the same type as the stored ones. This is useful for cases
where bulk parameter replacement is required without needing to change types, for example,
in optimization methods where the gradient is not computed using dual numbers (as
demonstrated above).
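
The corresponding pattern, condensed from `loss2` above (`x` again stands for the new
tunable values):

```julia
newprob = remake(odeprob) # copy the problem, keeping all types intact
# overwrite the `Tunable` portion of the copy in-place; `x` must match the
# stored element type (or be convertible to it)
replace!(Tunable(), parameter_values(newprob), x)
```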