Skip to content

Commit 8885137

Browse files
feat: allow build_explicit_function to generate param-only observed
1 parent d64f973 commit 8885137

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ function build_explicit_observed_function(sys, ts;
385385
drop_expr = drop_expr,
386386
ps = full_parameters(sys),
387387
return_inplace = false,
388+
param_only = false,
388389
op = Operator,
389390
throw = true)
390391
if (isscalar = symbolic_type(ts) !== NotSymbolic())
@@ -397,16 +398,26 @@ function build_explicit_observed_function(sys, ts;
397398
ivs = independent_variables(sys)
398399
dep_vars = scalarize(setdiff(vars, ivs))
399400

400-
obs = observed(sys)
401+
obs = param_only ? Equation[] : observed(sys)
402+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
403+
# each subsystem is topologically sorted independently. We can append the
404+
# equations to override the `lhs ~ 0` equations in `observed(sys)`
405+
syss, _, continuous_id, _... = dss
406+
for (i, subsys) in enumerate(syss)
407+
i == continuous_id && continue
408+
append!(obs, observed(subsys))
409+
end
410+
end
401411

402412
cs = collect_constants(obs)
403413
if !isempty(cs) > 0
404414
cmap = map(x -> x => getdefault(x), cs)
405415
obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs)
406416
end
407417

408-
sts = Set(unknowns(sys))
409-
sts = union(sts,
418+
sts = param_only ? Set() : Set(unknowns(sys))
419+
sts = param_only ? Set() :
420+
union(sts,
410421
Set(arguments(st)[1] for st in sts if iscall(st) && operation(st) === getindex))
411422

412423
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
@@ -418,7 +429,8 @@ function build_explicit_observed_function(sys, ts;
418429
Set(arguments(p)[1]
419430
for p in param_set_ns if iscall(p) && operation(p) === getindex))
420431
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
421-
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
432+
namespaced_to_sts = param_only ? Dict() :
433+
Dict(unknowns(sys, x) => x for x in unknowns(sys))
422434

423435
# FIXME: This is a rather rough estimate of dependencies. We assume
424436
# the expression depends on everything before the `maxidx`.
@@ -483,11 +495,11 @@ function build_explicit_observed_function(sys, ts;
483495
end
484496
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
485497
if inputs === nothing
486-
args = [dvs, ps..., ivs...]
498+
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
487499
else
488500
inputs = unwrap.(inputs)
489501
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
490-
args = [dvs, ipts, ps..., ivs...]
502+
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
491503
end
492504
pre = get_postprocess_fbody(sys)
493505

0 commit comments

Comments
 (0)