Skip to content

Commit 346d822

Browse files
fix: fix remake_buffer with dependent update, test with ForwardDiff
1 parent 36c1ad6 commit 346d822

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/systems/parameter_buffer.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,6 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
330330
for (portion, old) in [(SciMLStructures.Tunable(), oldbuf.tunable)
331331
(SciMLStructures.Discrete(), oldbuf.discrete)
332332
(SciMLStructures.Constants(), oldbuf.constant)
333-
(DEPENDENT_PORTION, oldbuf.dependent)
334333
(NONNUMERIC_PORTION, oldbuf.nonnumeric)]
335334
if isempty(old)
336335
push!(newbufs, old)
@@ -343,15 +342,19 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
343342
end
344343
push!(newbufs, Tuple(new))
345344
end
346-
newbuf = MTKParameters(
347-
newbufs..., oldbuf.dependent_update_iip, oldbuf.dependent_update_oop)
345+
tmpbuf = MTKParameters(
346+
newbufs[1], newbufs[2], newbufs[3], oldbuf.dependent, newbufs[4], nothing, nothing)
348347
for (p, val) in vals
349348
_set_parameter_unchecked!(
350-
newbuf, val, parameter_index(sys, p); update_dependent = false)
349+
tmpbuf, val, parameter_index(sys, p); update_dependent = false)
351350
end
352-
if newbuf.dependent_update_iip !== nothing
353-
newbuf.dependent_update_iip(ArrayPartition(newbuf.dependent), newbuf...)
351+
if oldbuf.dependent_update_oop !== nothing
352+
dependent = oldbuf.dependent_update_oop(tmpbuf...)
353+
else
354+
dependent = ()
354355
end
356+
newbuf = MTKParameters(newbufs[1], newbufs[2], newbufs[3], dependent, newbufs[4],
357+
oldbuf.dependent_update_iip, oldbuf.dependent_update_oop)
355358
return newbuf
356359
end
357360

test/mtkparameters.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using ForwarDiff
56

67
@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
78
@named sys = ODESystem(
@@ -90,3 +91,13 @@ end
9091
@test getp(sys, f)(newps) isa Matrix{UInt}
9192
# SII bug
9293
@test_broken getp(sys, g)(newps) isa Vector{Float32}
94+
95+
ps = MTKParameters(sys, ivs)
96+
function loss(value, sys, ps)
97+
@test value isa ForwardDiff.Dual
98+
vals = merge(Dict(parameters(sys) .=> getp(sys, parameters(sys))(ps)), Dict(a => value))
99+
ps = remake_buffer(sys, ps, vals)
100+
getp(sys, a)(ps) + getp(sys, b)(ps)
101+
end
102+
103+
@test ForwardDiff.derivative(x -> loss(x, sys, ps), 1.5) == 3.0

0 commit comments

Comments
 (0)