Skip to content

Commit 628de91

Browse files
fix: fix bug in remake_buffer
1 parent e1befe0 commit 628de91

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

src/systems/parameter_buffer.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ function MTKParameters(
130130
if has_parameter_dependencies(sys) &&
131131
(pdeps = get_parameter_dependencies(sys)) !== nothing
132132
pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps)
133-
dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer)...)
133+
dep_exprs = ArrayPartition((Any[0 for _ in eachindex(v)] for v in dep_buffer)...)
134134
for (sym, val) in pdeps
135135
i, j = ic.dependent_idx[sym]
136-
dep_exprs.x[i][j] = wrap(val)
136+
dep_exprs.x[i][j] = unwrap(val)
137137
end
138138
p = reorder_parameters(ic, full_parameters(sys))
139139
oop, iip = build_function(dep_exprs, p...)
@@ -398,7 +398,10 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
398398
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
399399
oldbuf.nonnumeric, newbuf.nonnumeric)
400400
if newbuf.dependent_update_oop !== nothing
401-
@set! newbuf.dependent = newbuf.dependent_update_oop(newbuf...)
401+
@set! newbuf.dependent = narrow_buffer_type_and_fallback_undefs.(
402+
oldbuf.dependent,
403+
split_into_buffers(
404+
newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(false)))
402405
end
403406
return newbuf
404407
end
@@ -422,6 +425,7 @@ _num_subarrays(v::Tuple) = length(v)
422425
# getindex indexes the vectors, setindex! linearly indexes values
423426
# it's inconsistent, but we need it to be this way
424427
function Base.getindex(buf::MTKParameters, i)
428+
i_orig = i
425429
if !isempty(buf.tunable)
426430
i <= _num_subarrays(buf.tunable) && return _subarrays(buf.tunable)[i]
427431
i -= _num_subarrays(buf.tunable)
@@ -442,7 +446,7 @@ function Base.getindex(buf::MTKParameters, i)
442446
i <= _num_subarrays(buf.dependent) && return _subarrays(buf.dependent)[i]
443447
i -= _num_subarrays(buf.dependent)
444448
end
445-
throw(BoundsError(buf, i))
449+
throw(BoundsError(buf, i_orig))
446450
end
447451
function Base.setindex!(p::MTKParameters, val, i)
448452
function _helper(buf)
@@ -526,9 +530,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
526530
for (i, val) in zip(input_idxs, p_small_inner)
527531
_set_parameter_unchecked!(p_big, val, i)
528532
end
529-
# tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
530-
# tunable[input_idxs] .= p_small_inner
531-
# p_big = repack(tunable)
532533
return if pf isa SciMLBase.ParamJacobianWrapper
533534
buffer = Array{dualtype}(undef, size(pf.u))
534535
pf(buffer, p_big)
@@ -538,8 +539,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
538539
end
539540
end
540541
end
541-
# tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
542-
# p_small = tunable[input_idxs]
543542
p_small = parameter_values.((p,), input_idxs)
544543
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
545544
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))

test/mtkparameters.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,20 @@ function loss(x)
224224
end
225225

226226
@test_nowarn ForwardDiff.gradient(loss, collect(tunables))
227+
228+
# Ensure dependent parameters are `Tuple{...}` and not `ArrayPartition` when using
229+
# `remake_buffer`.
230+
@parameters p1 p2 p3[1:2] p4[1:2]
231+
@named sys = ODESystem(
232+
Equation[], t, [], [p1, p2, p3, p4]; parameter_dependencies = [p2 => 2p1, p4 => 3p3])
233+
sys = complete(sys)
234+
ps = MTKParameters(sys, [p1 => 1.0, p3 => [2.0, 3.0]])
235+
@test ps[parameter_index(sys, p2)] == 2.0
236+
@test ps[parameter_index(sys, p4)] == [6.0, 9.0]
237+
238+
newps = remake_buffer(
239+
sys, ps, Dict(p1 => ForwardDiff.Dual(2.0), p3 => ForwardDiff.Dual.([3.0, 4.0])))
240+
241+
VDual = Vector{<:ForwardDiff.Dual}
242+
VVDual = Vector{<:Vector{<:ForwardDiff.Dual}}
243+
@test newps.dependent isa Union{Tuple{VDual, VVDual}, Tuple{VVDual, VDual}}

0 commit comments

Comments
 (0)