Skip to content

Commit d2df4c3

Browse files
fix: implement DiffEqBase.anyeltypedual for MTKParameters
1 parent 709148e commit d2df4c3

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

docs/src/examples/remake.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,8 @@ function loss(x, p)
5151
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
5252
ps = parameter_values(odeprob) # obtain the parameter object from the problem
5353
ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
54-
T = eltype(x)
55-
# we also have to convert the `u0` vector
56-
u0 = T.(state_values(odeprob))
5754
# remake the problem, passing in our new parameter object
58-
newprob = remake(odeprob; u0 = u0, p = ps)
55+
newprob = remake(odeprob; p = ps)
5956
timesteps = p[2]
6057
sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
6158
truth = p[3]

src/systems/parameter_buffer.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,15 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
384384
return newbuf
385385
end
386386

387+
function DiffEqBase.anyeltypedual(
388+
p::MTKParameters, ::Type{Val{counter}} = Val{0}) where {counter}
389+
DiffEqBase.anyeltypedual(p.tunable)
390+
end
391+
function DiffEqBase.anyeltypedual(p::Type{<:MTKParameters{T}},
392+
::Type{Val{counter}} = Val{0}) where {counter} where {T}
393+
DiffEqBase.__anyeltypedual(T)
394+
end
395+
387396
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
388397
_subarrays(v::ArrayPartition) = v.x
389398
_subarrays(v::Tuple) = v

test/mtkparameters.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using OrdinaryDiffEq
56
using ForwardDiff
67

78
@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
@@ -134,3 +135,24 @@ ps = [p => 1.0] # Value for `d` is missing
134135

135136
@test_throws ModelingToolkit.MissingVariablesError ODEProblem(sys, u0, tspan, ps)
136137
@test_nowarn ODEProblem(sys, u0, tspan, [ps..., d => 1.0])
138+
139+
# Issue#2642
140+
@parameters α β γ δ
141+
@variables x(t) y(t)
142+
eqs = [D(x) ~- β * y) * x
143+
D(y) ~* x - γ) * y]
144+
@mtkbuild odesys = ODESystem(eqs, t)
145+
odeprob = ODEProblem(
146+
odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
147+
tunables, _... = canonicalize(Tunable(), odeprob.p)
148+
@test tunables isa AbstractVector{Float64}
149+
150+
function loss(x)
151+
ps = odeprob.p
152+
newps = SciMLStructures.replace(Tunable(), ps, x)
153+
newprob = remake(odeprob, p = newps)
154+
sol = solve(newprob, Tsit5())
155+
return sum(sol)
156+
end
157+
158+
@test_nowarn ForwardDiff.gradient(loss, collect(tunables))

0 commit comments

Comments
 (0)