Skip to content

Commit ccf6358

Browse files
fix: fix bug in remake_buffer
1 parent b77d998 commit ccf6358

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

src/systems/parameter_buffer.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
397397
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
398398
oldbuf.nonnumeric, newbuf.nonnumeric)
399399
if newbuf.dependent_update_oop !== nothing
400-
@set! newbuf.dependent = newbuf.dependent_update_oop(newbuf...)
400+
@set! newbuf.dependent = split_into_buffers(
401+
newbuf.dependent_update_oop(newbuf...), oldbuf.dependent)
401402
end
402403
return newbuf
403404
end
@@ -421,6 +422,7 @@ _num_subarrays(v::Tuple) = length(v)
421422
# getindex indexes the vectors, setindex! linearly indexes values
422423
# it's inconsistent, but we need it to be this way
423424
function Base.getindex(buf::MTKParameters, i)
425+
i_orig = i
424426
if !isempty(buf.tunable)
425427
i <= _num_subarrays(buf.tunable) && return _subarrays(buf.tunable)[i]
426428
i -= _num_subarrays(buf.tunable)
@@ -441,7 +443,7 @@ function Base.getindex(buf::MTKParameters, i)
441443
i <= _num_subarrays(buf.dependent) && return _subarrays(buf.dependent)[i]
442444
i -= _num_subarrays(buf.dependent)
443445
end
444-
throw(BoundsError(buf, i))
446+
throw(BoundsError(buf, i_orig))
445447
end
446448
function Base.setindex!(p::MTKParameters, val, i)
447449
function _helper(buf)
@@ -525,9 +527,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
525527
for (i, val) in zip(input_idxs, p_small_inner)
526528
_set_parameter_unchecked!(p_big, val, i)
527529
end
528-
# tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
529-
# tunable[input_idxs] .= p_small_inner
530-
# p_big = repack(tunable)
531530
return if pf isa SciMLBase.ParamJacobianWrapper
532531
buffer = Array{dualtype}(undef, size(pf.u))
533532
pf(buffer, p_big)
@@ -537,8 +536,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
537536
end
538537
end
539538
end
540-
# tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
541-
# p_small = tunable[input_idxs]
542539
p_small = parameter_values.((p,), input_idxs)
543540
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
544541
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))

0 commit comments

Comments
 (0)