Skip to content

Commit d708e71

Browse files
feat: relax type restrictions in MTKParameters construction
1 parent a119d82 commit d708e71

File tree

3 files changed

+47
-14
lines changed

3 files changed

+47
-14
lines changed

src/systems/index_cache.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ function IndexCache(sys::AbstractSystem)
7777

7878
function insert_by_type!(buffers::Dict{Any, Set{BasicSymbolic}}, sym)
7979
sym = unwrap(sym)
80-
ctype = concrete_symtype(sym)
80+
ctype = symtype(sym)
8181
buf = get!(buffers, ctype, Set{BasicSymbolic}())
8282
push!(buf, sym)
8383
end
@@ -114,7 +114,7 @@ function IndexCache(sys::AbstractSystem)
114114

115115
for p in parameters(sys)
116116
p = unwrap(p)
117-
ctype = concrete_symtype(p)
117+
ctype = symtype(p)
118118
haskey(disc_buffers, ctype) && p in disc_buffers[ctype] && continue
119119
haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
120120
insert_by_type!(
@@ -310,9 +310,3 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
310310
end
311311
return result
312312
end
313-
314-
concrete_symtype(x::BasicSymbolic) = concrete_symtype(symtype(x))
315-
concrete_symtype(::Type{Real}) = Float64
316-
concrete_symtype(::Type{Integer}) = Int
317-
concrete_symtype(::Type{A}) where {T, N, A <: Array{T, N}} = Array{concrete_symtype(T), N}
318-
concrete_symtype(::Type{T}) where {T} = T

src/systems/parameter_buffer.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x)
22
symconvert(::Type{T}, x) where {T} = convert(T, x)
3+
symconvert(::Type{Real}, x::Integer) = convert(Float64, x)
4+
symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x))
5+
36
struct MTKParameters{T, D, C, E, N, F, G}
47
tunable::T
58
discrete::D
@@ -107,7 +110,7 @@ function MTKParameters(
107110
for (sym, val) in p
108111
sym = unwrap(sym)
109112
val = unwrap(val)
110-
ctype = concrete_symtype(sym)
113+
ctype = symtype(sym)
111114
if symbolic_type(val) !== NotSymbolic()
112115
continue
113116
end
@@ -126,19 +129,27 @@ function MTKParameters(
126129
end
127130
end
128131
end
132+
tunable_buffer = narrow_buffer_type.(tunable_buffer)
133+
disc_buffer = narrow_buffer_type.(disc_buffer)
134+
const_buffer = narrow_buffer_type.(const_buffer)
135+
nonnumeric_buffer = narrow_buffer_type.(nonnumeric_buffer)
129136

130137
if has_parameter_dependencies(sys) &&
131138
(pdeps = get_parameter_dependencies(sys)) !== nothing
132139
pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps)
133-
dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer)...)
140+
dep_exprs = ArrayPartition((Any[missing for _ in 1:length(v)] for v in dep_buffer)...)
134141
for (sym, val) in pdeps
135142
i, j = ic.dependent_idx[sym]
136143
dep_exprs.x[i][j] = wrap(val)
137144
end
145+
dep_exprs = identity.(dep_exprs)
138146
p = reorder_parameters(ic, full_parameters(sys))
139147
oop, iip = build_function(dep_exprs, p...)
140148
update_function_iip, update_function_oop = RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip),
141149
RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(oop)
150+
update_function_iip(ArrayPartition(dep_buffer), tunable_buffer..., disc_buffer...,
151+
const_buffer..., nonnumeric_buffer..., dep_buffer...)
152+
dep_buffer = narrow_buffer_type.(dep_buffer)
142153
else
143154
update_function_iip = update_function_oop = nothing
144155
end
@@ -148,12 +159,26 @@ function MTKParameters(
148159
typeof(dep_buffer), typeof(nonnumeric_buffer), typeof(update_function_iip),
149160
typeof(update_function_oop)}(tunable_buffer, disc_buffer, const_buffer, dep_buffer,
150161
nonnumeric_buffer, update_function_iip, update_function_oop)
151-
if mtkps.dependent_update_iip !== nothing
152-
mtkps.dependent_update_iip(ArrayPartition(mtkps.dependent), mtkps...)
153-
end
154162
return mtkps
155163
end
156164

165+
function narrow_buffer_type(buffer::AbstractArray)
166+
type = Union{}
167+
for x in buffer
168+
type = promote_type(type, typeof(x))
169+
end
170+
return convert.(type, buffer)
171+
end
172+
173+
function narrow_buffer_type(buffer::AbstractArray{<:AbstractArray})
174+
buffer = narrow_buffer_type.(buffer)
175+
type = Union{}
176+
for x in buffer
177+
type = promote_type(type, eltype(x))
178+
end
179+
return broadcast.(convert, type, buffer)
180+
end
181+
157182
function buffer_to_arraypartition(buf)
158183
return ArrayPartition(ntuple(i -> _buffer_to_arrp_helper(buf[i]), Val(length(buf))))
159184
end

test/odesystem.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using DiffEqBase, SparseArrays
66
using StaticArrays
77
using Test
88
using SymbolicUtils: issym
9-
9+
using ForwardDiff
1010
using ModelingToolkit: value
1111
using ModelingToolkit: t_nounits as t, D_nounits as D
1212

@@ -1159,3 +1159,17 @@ for sys in [sys1, sys2]
11591159
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
11601160
end
11611161
end
1162+
1163+
# Issue#2667
1164+
@testset "ForwardDiff through ODEProblem constructor" begin
1165+
@parameters P
1166+
@variables x(t)
1167+
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
1168+
1169+
function x_at_1(P)
1170+
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [sys.P => P])
1171+
return solve(prob)(1.0)
1172+
end
1173+
1174+
@test_nowarn ForwardDiff.derivative(P -> x_at_1(P), 1.0)
1175+
end

0 commit comments

Comments
 (0)