Skip to content

Commit d6880d0

Browse files
fix: implement DiffEqBase.anyeltypedual for MTKParameters
1 parent 709148e commit d6880d0

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

docs/src/examples/remake.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,8 @@ function loss(x, p)
5151
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
5252
ps = parameter_values(odeprob) # obtain the parameter object from the problem
5353
ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
54-
T = eltype(x)
55-
# we also have to convert the `u0` vector
56-
u0 = T.(state_values(odeprob))
5754
# remake the problem, passing in our new parameter object
58-
newprob = remake(odeprob; u0 = u0, p = ps)
55+
newprob = remake(odeprob; p = ps)
5956
timesteps = p[2]
6057
sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
6158
truth = p[3]

src/systems/parameter_buffer.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,15 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
384384
return newbuf
385385
end
386386

387+
function DiffEqBase.anyeltypedual(
388+
p::MTKParameters, ::Type{Val{counter}} = Val{0}) where {counter}
389+
DiffEqBase.anyeltypedual(p.tunable)
390+
end
391+
function DiffEqBase.anyeltypedual(p::Type{<:MTKParameters{T}},
392+
::Type{Val{counter}} = Val{0}) where {counter} where {T}
393+
DiffEqBase.__anyeltypedual(T)
394+
end
395+
387396
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
388397
_subarrays(v::ArrayPartition) = v.x
389398
_subarrays(v::Tuple) = v

test/mtkparameters.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,24 @@ ps = [p => 1.0] # Value for `d` is missing
134134

135135
@test_throws ModelingToolkit.MissingVariablesError ODEProblem(sys, u0, tspan, ps)
136136
@test_nowarn ODEProblem(sys, u0, tspan, [ps..., d => 1.0])
137+
138+
# Issue#2642
139+
@parameters α β γ δ
140+
@variables x(t) y(t)
141+
eqs = [D(x) ~- β * y) * x
142+
D(y) ~* x - γ) * y]
143+
@mtkbuild odesys = ODESystem(eqs, t)
144+
odeprob = ODEProblem(
145+
odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
146+
tunables, _... = canonicalize(Tunable(), odeprob.p)
147+
@test tunables isa AbstractVector{Float64}
148+
149+
function loss(x)
150+
ps = odeprob.p
151+
newps = replace(Tunable(), ps, x)
152+
newprob = SciMLStructures.remake(odeprob, p = newps)
153+
sol = solve(newprob, Tsit5())
154+
return sum(sol)
155+
end
156+
157+
@test_nowarn ForwardDiff.gradient(loss, collect(tunables))

0 commit comments

Comments
 (0)