Skip to content

Commit 4ece767

Browse files
refactor: use common implementation of observedfun
1 parent 0d450de commit 4ece767

File tree

6 files changed

+38
-136
lines changed

6 files changed

+38
-136
lines changed

src/systems/abstractsystem.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,30 @@ end
11921192
###
11931193
### System utils
11941194
###
1195+
struct ObservedFunctionCache{S}
1196+
sys::S
1197+
dict::Dict{Any, Any}
1198+
end
1199+
1200+
function ObservedFunctionCache(sys)
1201+
return ObservedFunctionCache(sys, Dict())
1202+
let sys = sys, dict = Dict()
1203+
function generated_observed(obsvar, args...)
1204+
end
1205+
end
1206+
end
1207+
1208+
function (ofc::ObservedFunctionCache)(obsvar, args...)
1209+
obs = get!(ofc.dict, value(obsvar)) do
1210+
SymbolicIndexingInterface.observed(ofc.sys, obsvar)
1211+
end
1212+
if args === ()
1213+
return obs
1214+
else
1215+
return obs(args...)
1216+
end
1217+
end
1218+
11951219
function push_vars!(stmt, name, typ, vars)
11961220
isempty(vars) && return
11971221
vars_expr = Expr(:macrocall, typ, nothing)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 10 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -404,82 +404,25 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
404404

405405
obs = observed(sys)
406406
observedfun = if steady_state
407-
let sys = sys, dict = Dict(), ps = ps
407+
let sys = sys, dict = Dict()
408408
function generated_observed(obsvar, args...)
409409
obs = get!(dict, value(obsvar)) do
410-
build_explicit_observed_function(sys, obsvar)
410+
SymbolicIndexingInterface.observed(sys, obsvar)
411411
end
412412
if args === ()
413-
let obs = obs, ps_T = typeof(ps)
414-
(u, p, t = Inf) -> if p isa MTKParameters
415-
obs(u, p..., t)
416-
elseif ps_T <: Tuple
417-
obs(u, p..., t)
418-
else
419-
obs(u, p, t)
420-
end
413+
return let obs = obs
414+
fn1(u, p, t = Inf) = obs(u, p, t)
415+
fn1
421416
end
417+
elseif length(args) == 2
418+
return obs(args..., Inf)
422419
else
423-
if args[2] isa MTKParameters
424-
if length(args) == 2
425-
u, p = args
426-
obs(u, p..., Inf)
427-
else
428-
u, p, t = args
429-
obs(u, p..., t)
430-
end
431-
elseif ps isa Tuple
432-
if length(args) == 2
433-
u, p = args
434-
obs(u, p..., Inf)
435-
else
436-
u, p, t = args
437-
obs(u, p..., t)
438-
end
439-
else
440-
if length(args) == 2
441-
u, p = args
442-
obs(u, p, Inf)
443-
else
444-
u, p, t = args
445-
obs(u, p, t)
446-
end
447-
end
420+
return obs(args...)
448421
end
449422
end
450423
end
451424
else
452-
let sys = sys, dict = Dict(), ps = ps
453-
function generated_observed(obsvar, args...)
454-
obs = get!(dict, value(obsvar)) do
455-
build_explicit_observed_function(sys,
456-
obsvar;
457-
checkbounds = checkbounds,
458-
ps)
459-
end
460-
if args === ()
461-
let obs = obs, ps_T = typeof(ps)
462-
(u, p, t) -> if p isa MTKParameters
463-
obs(u, p..., t)
464-
elseif ps_T <: Tuple
465-
obs(u, p..., t)
466-
else
467-
obs(u, p, t)
468-
end
469-
end
470-
else
471-
u, p, t = args
472-
if p isa MTKParameters
473-
u, p, t = args
474-
obs(u, p..., t)
475-
elseif ps isa Tuple # split parameters
476-
obs(u, p..., t)
477-
else
478-
obs(args...)
479-
end
480-
end
481-
end
482-
end
425+
ObservedFunctionCache(sys)
483426
end
484427

485428
jac_prototype = if sparse
@@ -571,24 +514,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
571514
_jac = nothing
572515
end
573516

574-
obs = observed(sys)
575-
observedfun = let sys = sys, dict = Dict()
576-
function generated_observed(obsvar, args...)
577-
obs = get!(dict, value(obsvar)) do
578-
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
579-
end
580-
if args === ()
581-
let obs = obs
582-
fun(u, p, t) = obs(u, p, t)
583-
fun(u, p::MTKParameters, t) = obs(u, p..., t)
584-
fun
585-
end
586-
else
587-
u, p, t = args
588-
p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t)
589-
end
590-
end
591-
end
517+
observedfun = ObservedFunctionCache(sys)
592518

593519
jac_prototype = if sparse
594520
uElType = u0 === nothing ? Float64 : eltype(u0)

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -484,19 +484,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
484484
M = calculate_massmatrix(sys)
485485
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
486486

487-
obs = observed(sys)
488-
observedfun = let sys = sys, dict = Dict()
489-
function generated_observed(obsvar, u, p, t)
490-
obs = get!(dict, value(obsvar)) do
491-
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
492-
end
493-
if p isa MTKParameters
494-
obs(u, p..., t)
495-
else
496-
obs(u, p, t)
497-
end
498-
end
499-
end
487+
observedfun = ObservedFunctionCache(sys)
500488

501489
SDEFunction{iip}(f, g,
502490
sys = sys,

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
330330
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
331331
end
332332

333-
observedfun = let sys = sys, dict = Dict()
334-
function generate_observed(obsvar, u, p, t)
335-
obs = get!(dict, value(obsvar)) do
336-
build_explicit_observed_function(sys, obsvar)
337-
end
338-
p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t)
339-
end
340-
end
333+
observedfun = ObservedFunctionCache(sys)
341334

342335
DiscreteFunction{iip, specialize}(f;
343336
sys = sys,

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -345,16 +345,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
345345

346346
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
347347

348-
# just taken from abstractodesystem.jl for ODEFunction def
349-
obs = observed(sys)
350-
observedfun = let sys = sys, dict = Dict()
351-
function generated_observed(obsvar, u, p, t)
352-
obs = get!(dict, value(obsvar)) do
353-
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
354-
end
355-
p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t)
356-
end
357-
end
348+
observedfun = ObservedFunctionCache(sys)
358349

359350
df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
360351
DiscreteProblem(df, u0, tspan, p; kwargs...)

src/systems/optimization/optimizationsystem.jl

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -337,27 +337,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
337337
hess_prototype = nothing
338338
end
339339

340-
observedfun = let sys = sys, dict = Dict()
341-
function generated_observed(obsvar, args...)
342-
obs = get!(dict, value(obsvar)) do
343-
build_explicit_observed_function(sys, obsvar)
344-
end
345-
if args === ()
346-
let obs = obs
347-
_obs(u, p) = obs(u, p)
348-
_obs(u, p::MTKParameters) = obs(u, p...)
349-
_obs
350-
end
351-
else
352-
u, p = args
353-
if p isa MTKParameters
354-
obs(u, p...)
355-
else
356-
obs(u, p)
357-
end
358-
end
359-
end
360-
end
340+
observedfun = ObservedFunctionCache(sys)
361341

362342
if length(cstr) > 0
363343
@named cons_sys = ConstraintsSystem(cstr, dvs, ps)

0 commit comments

Comments
 (0)