Skip to content

Commit f2197e3

Browse files
Merge pull request #2630 from AayushSabharwal/as/fix-arrayparam-initialization
fix: fix intialization of array parameters with unknown size
2 parents dde5b83 + f15fa79 commit f2197e3

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

src/systems/parameter_buffer.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ function MTKParameters(
4343
p = merge(defs, p)
4444
p = merge(Dict(unwrap(k) => v for (k, v) in p),
4545
Dict(default_toterm(unwrap(k)) => v for (k, v) in p))
46-
p = Dict(k => fixpoint_sub(v, p)
47-
for (k, v) in p if k in all_ps || default_toterm(k) in all_ps)
46+
p = Dict(k => fixpoint_sub(v, p) for (k, v) in p)
4847
for (sym, _) in p
4948
if istree(sym) && operation(sym) === getindex &&
5049
first(arguments(sym)) in all_ps
@@ -89,14 +88,24 @@ function MTKParameters(
8988

9089
for (sym, val) in p
9190
sym = unwrap(sym)
91+
val = unwrap(val)
9292
ctype = concrete_symtype(sym)
93-
val = symconvert(ctype, unwrap(fixpoint_sub(val, p)))
93+
if symbolic_type(val) !== NotSymbolic()
94+
continue
95+
end
96+
val = symconvert(ctype, val)
9497
done = set_value(sym, val)
9598
if !done && Symbolics.isarraysymbolic(sym)
96-
done = all(set_value.(collect(sym), val))
97-
end
98-
if !done
99-
error("Symbol $sym does not have an index")
99+
if Symbolics.shape(sym) === Symbolics.Unknown()
100+
for i in eachindex(val)
101+
set_value(sym[i], val[i])
102+
end
103+
else
104+
if size(sym) != size(val)
105+
error("Got value of size $(size(val)) for parameter $sym of size $(size(sym))")
106+
end
107+
set_value.(collect(sym), val)
108+
end
100109
end
101110
end
102111

@@ -341,6 +350,7 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
341350
if newbuf.dependent_update_oop !== nothing
342351
@set! newbuf.dependent = newbuf.dependent_update_oop(newbuf...)
343352
end
353+
return newbuf
344354
end
345355

346356
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)

test/mtkparameters.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,14 @@ function loss(value, sys, ps)
101101
end
102102

103103
@test ForwardDiff.derivative(x -> loss(x, sys, ps), 1.5) == 3.0
104+
105+
# Issue#2615
106+
@parameters p::Vector{Float64}
107+
@variables X(t)
108+
eq = D(X) ~ p[1] - p[2] * X
109+
@mtkbuild osys = ODESystem([eq], t)
110+
111+
u0 = [X => 1.0]
112+
ps = [p => [2.0, 0.1]]
113+
p = MTKParameters(osys, ps, u0)
114+
@test p.tunable[1] == [2.0, 0.1]

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ sol2 = @test_nowarn solve(prob2, Tsit5())
10241024
@test sol1 sol2
10251025

10261026
# Requires fix in symbolics for `linear_expansion(p * x, D(y))`
1027-
@test_broken begin
1027+
@test_skip begin
10281028
@variables x(t)[1:3] y(t)
10291029
@parameters p[1:3, 1:3]
10301030
@test_nowarn @mtkbuild sys = ODESystem([D(x) ~ p * x, D(y) ~ x' * p * x], t)

0 commit comments

Comments
 (0)