Skip to content

Commit 6015036

Browse files
feat: implement SII.remake_buffer for MTKParameters
1 parent 7eff5dd commit 6015036

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ SimpleNonlinearSolve = "0.1.0, 1"
102102
SparseArrays = "1"
103103
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
104104
StaticArrays = "0.10, 0.11, 0.12, 1.0"
105-
SymbolicIndexingInterface = "0.3.11"
105+
SymbolicIndexingInterface = "0.3.12"
106106
SymbolicUtils = "1.0"
107107
Symbolics = "5.26"
108108
URIs = "1"

src/systems/parameter_buffer.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,43 @@ function _set_parameter_unchecked!(
318318
p.dependent_update_iip(ArrayPartition(p.dependent), p...)
319319
end
320320

321+
function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, vals::Dict)
322+
buftypes = Dict{Tuple{Any, Int}, Any}()
323+
for (p, val) in vals
324+
(idx = parameter_index(sys, p)) isa ParameterIndex || continue
325+
k = (idx.portion, idx.idx[1])
326+
buftypes[k] = Union{get(buftypes, k, Union{}), typeof(val)}
327+
end
328+
329+
newbufs = []
330+
for (portion, old) in [(SciMLStructures.Tunable(), oldbuf.tunable)
331+
(SciMLStructures.Discrete(), oldbuf.discrete)
332+
(SciMLStructures.Constants(), oldbuf.constant)
333+
(DEPENDENT_PORTION, oldbuf.dependent)
334+
(NONNUMERIC_PORTION, oldbuf.nonnumeric)]
335+
if isempty(old)
336+
push!(newbufs, old)
337+
continue
338+
end
339+
new = Any[copy(i) for i in old]
340+
for i in eachindex(new)
341+
buftype = get(buftypes, (portion, i), eltype(new[i]))
342+
new[i] = similar(new[i], buftype)
343+
end
344+
push!(newbufs, Tuple(new))
345+
end
346+
newbuf = MTKParameters(
347+
newbufs..., oldbuf.dependent_update_iip, oldbuf.dependent_update_oop)
348+
for (p, val) in vals
349+
_set_parameter_unchecked!(
350+
newbuf, val, parameter_index(sys, p); update_dependent = false)
351+
end
352+
if newbuf.dependent_update_iip !== nothing
353+
newbuf.dependent_update_iip(ArrayPartition(newbuf.dependent), newbuf...)
354+
end
355+
return newbuf
356+
end
357+
321358
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
322359
_subarrays(v::ArrayPartition) = v.x
323360
_subarrays(v::Tuple) = v

0 commit comments

Comments
 (0)