Skip to content

Commit 7c3817e

Browse files
fixup! feat: support inplace parameter observed
1 parent c79851f commit 7c3817e

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/systems/abstractsystem.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
201201
end
202202
end
203203

204+
function wrap_assignments(isscalar, assignments; let_block = false)
205+
function wrapper(expr)
206+
Func(expr.args, [], Let(assignments, expr.body, let_block))
207+
end
208+
if isscalar
209+
wrapper
210+
else
211+
wrapper, wrapper
212+
end
213+
end
214+
204215
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
205216
isscalar = !(exprs isa AbstractArray)
206217
array_vars = Dict{Any, AbstractArray{Int}}()
@@ -505,7 +516,7 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
505516
ts_idx = nothing
506517
end
507518
rawobs = build_explicit_observed_function(
508-
sys, sym; param_only = true, return_inplace = true)
519+
sys, sym; param_only = true, return_inplace = true)
509520
if rawobs isa Tuple
510521
obsfn = let oop = rawobs[1], iip = rawobs[2]
511522
f1(p::MTKParameters, t) = oop(p..., t)

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,17 +487,23 @@ function build_explicit_observed_function(sys, ts;
487487
if inputs === nothing
488488
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
489489
else
490-
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
490+
ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds)
491491
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
492492
end
493493
pre = get_postprocess_fbody(sys)
494-
res = build_function(isscalar ? ts[1] : ts, args...; get_postprocess_fbody = pre, wrap_code = wrap_array_vars(sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)), expression = Val{expression})
494+
res = build_function(isscalar ? ts[1] : ts,
495+
args...;
496+
postprocess_fbody = pre,
497+
wrap_code = wrap_array_vars(
498+
sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)) .∘
499+
wrap_assignments(isscalar, obsexprs),
500+
expression = Val{expression})
495501
if isscalar || return_inplace
496502
return res
497503
else
498504
return res[1]
499505
end
500-
506+
501507
ex = Func(args, [],
502508
pre(Let(obsexprs,
503509
isscalar ? ts[1] : MakeArray(ts, output_type),

0 commit comments

Comments
 (0)