Skip to content

Commit ada31d0

Browse files
feat: relax type restrictions in MTKParameters construction
1 parent e1befe0 commit ada31d0

File tree

3 files changed

+47
-14
lines changed

3 files changed

+47
-14
lines changed

src/systems/index_cache.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function IndexCache(sys::AbstractSystem)
7979

8080
function insert_by_type!(buffers::Dict{Any, Set{BasicSymbolic}}, sym)
8181
sym = unwrap(sym)
82-
ctype = concrete_symtype(sym)
82+
ctype = symtype(sym)
8383
buf = get!(buffers, ctype, Set{BasicSymbolic}())
8484
push!(buf, sym)
8585
end
@@ -116,7 +116,7 @@ function IndexCache(sys::AbstractSystem)
116116

117117
for p in parameters(sys)
118118
p = unwrap(p)
119-
ctype = concrete_symtype(p)
119+
ctype = symtype(p)
120120
haskey(disc_buffers, ctype) && p in disc_buffers[ctype] && continue
121121
haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
122122
insert_by_type!(
@@ -312,9 +312,3 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
312312
end
313313
return result
314314
end
315-
316-
concrete_symtype(x::BasicSymbolic) = concrete_symtype(symtype(x))
317-
concrete_symtype(::Type{Real}) = Float64
318-
concrete_symtype(::Type{Integer}) = Int
319-
concrete_symtype(::Type{A}) where {T, N, A <: Array{T, N}} = Array{concrete_symtype(T), N}
320-
concrete_symtype(::Type{T}) where {T} = T

src/systems/parameter_buffer.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x)
22
symconvert(::Type{T}, x) where {T} = convert(T, x)
3+
symconvert(::Type{Real}, x::Integer) = convert(Float64, x)
4+
symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x))
5+
36
struct MTKParameters{T, D, C, E, N, F, G}
47
tunable::T
58
discrete::D
@@ -107,7 +110,7 @@ function MTKParameters(
107110
for (sym, val) in p
108111
sym = unwrap(sym)
109112
val = unwrap(val)
110-
ctype = concrete_symtype(sym)
113+
ctype = symtype(sym)
111114
if symbolic_type(val) !== NotSymbolic()
112115
continue
113116
end
@@ -126,19 +129,27 @@ function MTKParameters(
126129
end
127130
end
128131
end
132+
tunable_buffer = narrow_buffer_type.(tunable_buffer)
133+
disc_buffer = narrow_buffer_type.(disc_buffer)
134+
const_buffer = narrow_buffer_type.(const_buffer)
135+
nonnumeric_buffer = narrow_buffer_type.(nonnumeric_buffer)
129136

130137
if has_parameter_dependencies(sys) &&
131138
(pdeps = get_parameter_dependencies(sys)) !== nothing
132139
pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps)
133-
dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer)...)
140+
dep_exprs = ArrayPartition((Any[missing for _ in 1:length(v)] for v in dep_buffer)...)
134141
for (sym, val) in pdeps
135142
i, j = ic.dependent_idx[sym]
136143
dep_exprs.x[i][j] = wrap(val)
137144
end
145+
dep_exprs = identity.(dep_exprs)
138146
p = reorder_parameters(ic, full_parameters(sys))
139147
oop, iip = build_function(dep_exprs, p...)
140148
update_function_iip, update_function_oop = RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip),
141149
RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(oop)
150+
update_function_iip(ArrayPartition(dep_buffer), tunable_buffer..., disc_buffer...,
151+
const_buffer..., nonnumeric_buffer..., dep_buffer...)
152+
dep_buffer = narrow_buffer_type.(dep_buffer)
142153
else
143154
update_function_iip = update_function_oop = nothing
144155
end
@@ -148,12 +159,26 @@ function MTKParameters(
148159
typeof(dep_buffer), typeof(nonnumeric_buffer), typeof(update_function_iip),
149160
typeof(update_function_oop)}(tunable_buffer, disc_buffer, const_buffer, dep_buffer,
150161
nonnumeric_buffer, update_function_iip, update_function_oop)
151-
if mtkps.dependent_update_iip !== nothing
152-
mtkps.dependent_update_iip(ArrayPartition(mtkps.dependent), mtkps...)
153-
end
154162
return mtkps
155163
end
156164

165+
function narrow_buffer_type(buffer::AbstractArray)
166+
type = Union{}
167+
for x in buffer
168+
type = promote_type(type, typeof(x))
169+
end
170+
return convert.(type, buffer)
171+
end
172+
173+
function narrow_buffer_type(buffer::AbstractArray{<:AbstractArray})
174+
buffer = narrow_buffer_type.(buffer)
175+
type = Union{}
176+
for x in buffer
177+
type = promote_type(type, eltype(x))
178+
end
179+
return broadcast.(convert, type, buffer)
180+
end
181+
157182
function buffer_to_arraypartition(buf)
158183
return ArrayPartition(ntuple(i -> _buffer_to_arrp_helper(buf[i]), Val(length(buf))))
159184
end

test/odesystem.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using DiffEqBase, SparseArrays
66
using StaticArrays
77
using Test
88
using SymbolicUtils: issym
9-
9+
using ForwardDiff
1010
using ModelingToolkit: value
1111
using ModelingToolkit: t_nounits as t, D_nounits as D
1212

@@ -1168,3 +1168,17 @@ end
11681168
@named sys = ODESystem(Equation[], t)
11691169
@test getname(unknowns(sys, x)) == :sys₊x
11701170
@test size(unknowns(sys, x)) == size(x)
1171+
1172+
# Issue#2667
1173+
@testset "ForwardDiff through ODEProblem constructor" begin
1174+
@parameters P
1175+
@variables x(t)
1176+
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
1177+
1178+
function x_at_1(P)
1179+
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [sys.P => P])
1180+
return solve(prob, Tsit5())(1.0)
1181+
end
1182+
1183+
@test_nowarn ForwardDiff.derivative(P -> x_at_1(P), 1.0)
1184+
end

0 commit comments

Comments
 (0)