Skip to content

Commit 0116cd1

Browse files
Merge pull request #2570 from AayushSabharwal/as/remake-hook
feat: implement SII.remake_buffer for MTKParameters, fix bugs, add tests
2 parents b0542be + e476efa commit 0116cd1

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ SimpleNonlinearSolve = "0.1.0, 1"
102102
SparseArrays = "1"
103103
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
104104
StaticArrays = "0.10, 0.11, 0.12, 1.0"
105-
SymbolicIndexingInterface = "0.3.11"
105+
SymbolicIndexingInterface = "0.3.12"
106106
SymbolicUtils = "1.0"
107107
Symbolics = "5.26"
108108
URIs = "1"

src/systems/parameter_buffer.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,46 @@ function _set_parameter_unchecked!(
318318
p.dependent_update_iip(ArrayPartition(p.dependent), p...)
319319
end
320320

321+
function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, vals::Dict)
322+
buftypes = Dict{Tuple{Any, Int}, Any}()
323+
for (p, val) in vals
324+
(idx = parameter_index(sys, p)) isa ParameterIndex || continue
325+
k = (idx.portion, idx.idx[1])
326+
buftypes[k] = Union{get(buftypes, k, Union{}), typeof(val)}
327+
end
328+
329+
newbufs = []
330+
for (portion, old) in [(SciMLStructures.Tunable(), oldbuf.tunable)
331+
(SciMLStructures.Discrete(), oldbuf.discrete)
332+
(SciMLStructures.Constants(), oldbuf.constant)
333+
(NONNUMERIC_PORTION, oldbuf.nonnumeric)]
334+
if isempty(old)
335+
push!(newbufs, old)
336+
continue
337+
end
338+
new = Any[copy(i) for i in old]
339+
for i in eachindex(new)
340+
buftype = get(buftypes, (portion, i), eltype(new[i]))
341+
new[i] = similar(new[i], buftype)
342+
end
343+
push!(newbufs, Tuple(new))
344+
end
345+
tmpbuf = MTKParameters(
346+
newbufs[1], newbufs[2], newbufs[3], oldbuf.dependent, newbufs[4], nothing, nothing)
347+
for (p, val) in vals
348+
_set_parameter_unchecked!(
349+
tmpbuf, val, parameter_index(sys, p); update_dependent = false)
350+
end
351+
if oldbuf.dependent_update_oop !== nothing
352+
dependent = oldbuf.dependent_update_oop(tmpbuf...)
353+
else
354+
dependent = ()
355+
end
356+
newbuf = MTKParameters(newbufs[1], newbufs[2], newbufs[3], dependent, newbufs[4],
357+
oldbuf.dependent_update_iip, oldbuf.dependent_update_oop)
358+
return newbuf
359+
end
360+
321361
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
322362
_subarrays(v::ArrayPartition) = v.x
323363
_subarrays(v::Tuple) = v

test/mtkparameters.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using ForwardDiff
56

67
@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
78
@named sys = ODESystem(
@@ -70,3 +71,33 @@ setp(sys, f[2, 2])(ps, 42) # with a sub-index
7071

7172
setp(sys, h)(ps, "bar") # with a non-numeric
7273
@test getp(sys, h)(ps) == "bar"
74+
75+
newps = remake_buffer(sys,
76+
ps,
77+
Dict(a => 1.0f0, b => 5.0f0, c => 2.0, d => 0x5, e => [0.4, 0.5, 0.6],
78+
f => 3ones(UInt, 3, 3), g => ones(Float32, 4), h => "bar"))
79+
80+
for fname in (:tunable, :discrete, :constant, :dependent)
81+
# ensure same number of sub-buffers
82+
@test length(getfield(ps, fname)) == length(getfield(newps, fname))
83+
end
84+
@test ps.dependent_update_iip === newps.dependent_update_iip
85+
@test ps.dependent_update_oop === newps.dependent_update_oop
86+
87+
@test getp(sys, a)(newps) isa Float32
88+
@test getp(sys, b)(newps) == 2.0f0 # ensure dependent update still happened, despite explicit value
89+
@test getp(sys, c)(newps) isa Float64
90+
@test getp(sys, d)(newps) isa UInt8
91+
@test getp(sys, f)(newps) isa Matrix{UInt}
92+
# SII bug
93+
@test_broken getp(sys, g)(newps) isa Vector{Float32}
94+
95+
ps = MTKParameters(sys, ivs)
96+
function loss(value, sys, ps)
97+
@test value isa ForwardDiff.Dual
98+
vals = merge(Dict(parameters(sys) .=> getp(sys, parameters(sys))(ps)), Dict(a => value))
99+
ps = remake_buffer(sys, ps, vals)
100+
getp(sys, a)(ps) + getp(sys, b)(ps)
101+
end
102+
103+
@test ForwardDiff.derivative(x -> loss(x, sys, ps), 1.5) == 3.0

0 commit comments

Comments
 (0)