Skip to content

Commit 9ac6d78

Browse files
fix: fix intialization of array parameters with unknown size
1 parent 971359a commit 9ac6d78

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/systems/parameter_buffer.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ function MTKParameters(
1818
error("Cannot create MTKParameters if system does not have index_cache")
1919
end
2020
all_ps = Set(unwrap.(full_parameters(sys)))
21+
arr_ps = Set(arguments(p)[1] for p in all_ps if istree(p) && operation(p) === getindex)
22+
all_ps = union(all_ps, arr_ps)
2123
union!(all_ps, default_toterm.(unwrap.(full_parameters(sys))))
2224
if p isa Vector && !(eltype(p) <: Pair) && !isempty(p)
2325
ps = full_parameters(sys)
@@ -93,7 +95,14 @@ function MTKParameters(
9395
val = symconvert(ctype, unwrap(fixpoint_sub(val, p)))
9496
done = set_value(sym, val)
9597
if !done && Symbolics.isarraysymbolic(sym)
96-
done = all(set_value.(collect(sym), val))
98+
if Symbolics.shape(sym) === Symbolics.Unknown()
99+
done = all(set_value(sym[i], val[i]) for i in eachindex(val))
100+
else
101+
if size(sym) != size(val)
102+
error("Got value of size $(size(val)) for parameter $sym of size $(size(sym))")
103+
end
104+
done = all(set_value.(collect(sym), val))
105+
end
97106
end
98107
if !done
99108
error("Symbol $sym does not have an index")

test/mtkparameters.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,14 @@ function loss(value, sys, ps)
101101
end
102102

103103
@test ForwardDiff.derivative(x -> loss(x, sys, ps), 1.5) == 3.0
104+
105+
# Issue#2615
106+
@parameters p::Vector{Float64}
107+
@variables X(t)
108+
eq = D(X) ~ p[1] - p[2] * X
109+
@mtkbuild osys = ODESystem([eq], t)
110+
111+
u0 = [X => 1.0]
112+
ps = [p => [2.0, 0.1]]
113+
p = MTKParameters(osys, ps, u0)
114+
@test p.tunable[1] == [2.0, 0.1]

0 commit comments

Comments
 (0)