Skip to content

Commit e721fe7

Browse files
feat: support inplace parameter observed
1 parent 7392a5d commit e721fe7

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/systems/abstractsystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
201201
end
202202
end
203203

204+
function wrap_assignments(isscalar, assignments; let_block = false)
205+
function wrapper(expr)
206+
Func(expr.args, [], Let(assignments, expr.body, let_block))
207+
end
208+
if isscalar
209+
wrapper
210+
else
211+
wrapper, wrapper
212+
end
213+
end
214+
204215
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
205216
isscalar = !(exprs isa AbstractArray)
206217
array_vars = Dict{Any, AbstractArray{Int}}()

src/systems/diffeqs/odesystem.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ function build_explicit_observed_function(sys, ts;
381381
drop_expr = drop_expr,
382382
ps = full_parameters(sys),
383383
param_only = false,
384+
return_inplace = false,
384385
op = Operator,
385386
throw = true)
386387
if (isscalar = symbolic_type(ts) !== NotSymbolic())
@@ -486,10 +487,22 @@ function build_explicit_observed_function(sys, ts;
486487
if inputs === nothing
487488
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
488489
else
489-
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
490+
ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds)
490491
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
491492
end
492493
pre = get_postprocess_fbody(sys)
494+
res = build_function(isscalar ? ts[1] : ts,
495+
args...;
496+
postprocess_fbody = pre,
497+
wrap_code = wrap_array_vars(
498+
sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)) .∘
499+
wrap_assignments(isscalar, obsexprs),
500+
expression = Val{expression})
501+
if isscalar || return_inplace
502+
return res
503+
else
504+
return res[1]
505+
end
493506

494507
ex = Func(args, [],
495508
pre(Let(obsexprs,

0 commit comments

Comments
 (0)