Skip to content

Commit 909aec1

Browse files
feat: allow build_explicit_function to generate param-only observed
1 parent 3898fce commit 909aec1

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ function build_explicit_observed_function(sys, ts;
380380
checkbounds = true,
381381
drop_expr = drop_expr,
382382
ps = full_parameters(sys),
383+
param_only = false,
383384
op = Operator,
384385
throw = true)
385386
if (isscalar = symbolic_type(ts) !== NotSymbolic())
@@ -392,16 +393,26 @@ function build_explicit_observed_function(sys, ts;
392393
ivs = independent_variables(sys)
393394
dep_vars = scalarize(setdiff(vars, ivs))
394395

395-
obs = observed(sys)
396+
obs = param_only ? Equation[] : observed(sys)
397+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
398+
# each subsystem is topologically sorted independently. We can append the
399+
# equations to override the `lhs ~ 0` equations in `observed(sys)`
400+
syss, _, continuous_id, _... = dss
401+
for (i, subsys) in enumerate(syss)
402+
i == continuous_id && continue
403+
append!(obs, observed(subsys))
404+
end
405+
end
396406

397407
cs = collect_constants(obs)
398408
if !isempty(cs) > 0
399409
cmap = map(x -> x => getdefault(x), cs)
400410
obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs)
401411
end
402412

403-
sts = Set(unknowns(sys))
404-
sts = union(sts,
413+
sts = param_only ? Set() : Set(unknowns(sys))
414+
sts = param_only ? Set() :
415+
union(sts,
405416
Set(arguments(st)[1] for st in sts if istree(st) && operation(st) === getindex))
406417

407418
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
@@ -413,7 +424,8 @@ function build_explicit_observed_function(sys, ts;
413424
Set(arguments(p)[1]
414425
for p in param_set_ns if istree(p) && operation(p) === getindex))
415426
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
416-
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
427+
namespaced_to_sts = param_only ? Dict() :
428+
Dict(unknowns(sys, x) => x for x in unknowns(sys))
417429

418430
# FIXME: This is a rather rough estimate of dependencies. We assume
419431
# the expression depends on everything before the `maxidx`.
@@ -472,17 +484,18 @@ function build_explicit_observed_function(sys, ts;
472484
end
473485
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
474486
if inputs === nothing
475-
args = [dvs, ps..., ivs...]
487+
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
476488
else
477489
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
478-
args = [dvs, ipts, ps..., ivs...]
490+
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
479491
end
480492
pre = get_postprocess_fbody(sys)
481493

482494
ex = Func(args, [],
483495
pre(Let(obsexprs,
484496
isscalar ? ts[1] : MakeArray(ts, output_type),
485-
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
497+
false))) |>
498+
wrap_array_vars(sys, ts; dvs = param_only ? [] : unknowns(sys))[1] |> toexpr
486499
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
487500
end
488501

0 commit comments

Comments
 (0)