Skip to content

Commit b28e347

Browse files
Merge pull request #2610 from AayushSabharwal/as/fix-mtkparams-replace
refactor: improve replace, remake_buffer
2 parents 0dc0f5d + 0a290b7 commit b28e347

File tree

1 file changed

+21
-36
lines changed

1 file changed

+21
-36
lines changed

src/systems/parameter_buffer.jl

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,7 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable)
192192
end
193193

194194
@eval function SciMLStructures.replace!(::$Portion, p::MTKParameters, newvals)
195-
src = split_into_buffers(newvals, p.$field)
196-
for i in 1:length(p.$field)
197-
(p.$field)[i] .= src[i]
198-
end
195+
update_tuple_of_buffers(newvals, p.$field)
199196
if p.dependent_update_iip !== nothing
200197
p.dependent_update_iip(ArrayPartition(p.dependent), p...)
201198
end
@@ -318,44 +315,32 @@ function _set_parameter_unchecked!(
318315
p.dependent_update_iip(ArrayPartition(p.dependent), p...)
319316
end
320317

321-
function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, vals::Dict)
322-
buftypes = Dict{Tuple{Any, Int}, Any}()
323-
for (p, val) in vals
324-
(idx = parameter_index(sys, p)) isa ParameterIndex || continue
325-
k = (idx.portion, idx.idx[1])
326-
buftypes[k] = Union{get(buftypes, k, Union{}), typeof(val)}
318+
function narrow_buffer_type(buffer::Vector)
319+
type = Union{}
320+
for x in buffer
321+
type = Union{type, typeof(x)}
327322
end
323+
return convert(Vector{type}, buffer)
324+
end
325+
326+
function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, vals::Dict)
327+
newbuf = @set oldbuf.tunable = similar.(oldbuf.tunable, Any)
328+
@set! newbuf.discrete = similar.(newbuf.discrete, Any)
329+
@set! newbuf.constant = similar.(newbuf.constant, Any)
330+
@set! newbuf.nonnumeric = similar.(newbuf.nonnumeric, Any)
328331

329-
newbufs = []
330-
for (portion, old) in [(SciMLStructures.Tunable(), oldbuf.tunable)
331-
(SciMLStructures.Discrete(), oldbuf.discrete)
332-
(SciMLStructures.Constants(), oldbuf.constant)
333-
(NONNUMERIC_PORTION, oldbuf.nonnumeric)]
334-
if isempty(old)
335-
push!(newbufs, old)
336-
continue
337-
end
338-
new = Any[copy(i) for i in old]
339-
for i in eachindex(new)
340-
buftype = get(buftypes, (portion, i), eltype(new[i]))
341-
new[i] = similar(new[i], buftype)
342-
end
343-
push!(newbufs, Tuple(new))
344-
end
345-
tmpbuf = MTKParameters(
346-
newbufs[1], newbufs[2], newbufs[3], oldbuf.dependent, newbufs[4], nothing, nothing)
347332
for (p, val) in vals
348333
_set_parameter_unchecked!(
349-
tmpbuf, val, parameter_index(sys, p); update_dependent = false)
334+
newbuf, val, parameter_index(sys, p); update_dependent = false)
350335
end
351-
if oldbuf.dependent_update_oop !== nothing
352-
dependent = oldbuf.dependent_update_oop(tmpbuf...)
353-
else
354-
dependent = ()
336+
337+
@set! newbuf.tunable = narrow_buffer_type.(newbuf.tunable)
338+
@set! newbuf.discrete = narrow_buffer_type.(newbuf.discrete)
339+
@set! newbuf.constant = narrow_buffer_type.(newbuf.constant)
340+
@set! newbuf.nonnumeric = narrow_buffer_type.(newbuf.nonnumeric)
341+
if newbuf.dependent_update_oop !== nothing
342+
@set! newbuf.dependent = newbuf.dependent_update_oop(newbuf...)
355343
end
356-
newbuf = MTKParameters(newbufs[1], newbufs[2], newbufs[3], dependent, newbufs[4],
357-
oldbuf.dependent_update_iip, oldbuf.dependent_update_oop)
358-
return newbuf
359344
end
360345

361346
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)

0 commit comments

Comments
 (0)