Skip to content

Commit 04e8d12

Browse files
feat: allow build_explicit_function to generate param-only observed
1 parent 110ebfd commit 04e8d12

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ function build_explicit_observed_function(sys, ts;
387387
drop_expr = drop_expr,
388388
ps = full_parameters(sys),
389389
return_inplace = false,
390+
param_only = false,
390391
op = Operator,
391392
throw = true)
392393
if (isscalar = symbolic_type(ts) !== NotSymbolic())
@@ -399,16 +400,26 @@ function build_explicit_observed_function(sys, ts;
399400
ivs = independent_variables(sys)
400401
dep_vars = scalarize(setdiff(vars, ivs))
401402

402-
obs = observed(sys)
403+
obs = param_only ? Equation[] : observed(sys)
404+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
405+
# each subsystem is topologically sorted independently. We can append the
406+
# equations to override the `lhs ~ 0` equations in `observed(sys)`
407+
syss, _, continuous_id, _... = dss
408+
for (i, subsys) in enumerate(syss)
409+
i == continuous_id && continue
410+
append!(obs, observed(subsys))
411+
end
412+
end
403413

404414
cs = collect_constants(obs)
405415
if !isempty(cs) > 0
406416
cmap = map(x -> x => getdefault(x), cs)
407417
obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs)
408418
end
409419

410-
sts = Set(unknowns(sys))
411-
sts = union(sts,
420+
sts = param_only ? Set() : Set(unknowns(sys))
421+
sts = param_only ? Set() :
422+
union(sts,
412423
Set(arguments(st)[1] for st in sts if iscall(st) && operation(st) === getindex))
413424

414425
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
@@ -420,7 +431,8 @@ function build_explicit_observed_function(sys, ts;
420431
Set(arguments(p)[1]
421432
for p in param_set_ns if iscall(p) && operation(p) === getindex))
422433
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
423-
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
434+
namespaced_to_sts = param_only ? Dict() :
435+
Dict(unknowns(sys, x) => x for x in unknowns(sys))
424436

425437
# FIXME: This is a rather rough estimate of dependencies. We assume
426438
# the expression depends on everything before the `maxidx`.
@@ -485,11 +497,11 @@ function build_explicit_observed_function(sys, ts;
485497
end
486498
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
487499
if inputs === nothing
488-
args = [dvs, ps..., ivs...]
500+
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
489501
else
490502
inputs = unwrap.(inputs)
491503
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
492-
args = [dvs, ipts, ps..., ivs...]
504+
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
493505
end
494506
pre = get_postprocess_fbody(sys)
495507

0 commit comments

Comments
 (0)