Skip to content

Commit 52a1d33

Browse files
feat: support inplace observed
1 parent ba04751 commit 52a1d33

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

src/systems/abstractsystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
201201
end
202202
end
203203

204+
function wrap_assignments(isscalar, assignments; let_block = false)
205+
function wrapper(expr)
206+
Func(expr.args, [], Let(assignments, expr.body, let_block))
207+
end
208+
if isscalar
209+
wrapper
210+
else
211+
wrapper, wrapper
212+
end
213+
end
214+
204215
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
205216
isscalar = !(exprs isa AbstractArray)
206217
array_vars = Dict{Any, AbstractArray{Int}}()

src/systems/diffeqs/odesystem.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ function build_explicit_observed_function(sys, ts;
382382
checkbounds = true,
383383
drop_expr = drop_expr,
384384
ps = full_parameters(sys),
385+
return_inplace = false,
385386
op = Operator,
386387
throw = true)
387388
if (isscalar = symbolic_type(ts) !== NotSymbolic())
@@ -479,16 +480,21 @@ function build_explicit_observed_function(sys, ts;
479480
if inputs === nothing
480481
args = [dvs, ps..., ivs...]
481482
else
482-
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
483+
ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds)
483484
args = [dvs, ipts, ps..., ivs...]
484485
end
485486
pre = get_postprocess_fbody(sys)
486-
487-
ex = Func(args, [],
488-
pre(Let(obsexprs,
489-
isscalar ? ts[1] : MakeArray(ts, output_type),
490-
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
491-
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
487+
res = build_function(isscalar ? ts[1] : ts,
488+
args...;
489+
postprocess_fbody = pre,
490+
wrap_code = wrap_array_vars(
491+
sys, isscalar ? ts[1] : ts) .∘ wrap_assignments(isscalar, obsexprs),
492+
expression = Val{expression})
493+
if isscalar || return_inplace
494+
return res
495+
else
496+
return res[1]
497+
end
492498
end
493499

494500
function _eq_unordered(a, b)

test/odesystem.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,3 +1182,15 @@ end
11821182

11831183
@test_nowarn ForwardDiff.derivative(P -> x_at_1(P), 1.0)
11841184
end
1185+
1186+
@testset "Inplace observed functions" begin
1187+
@parameters P
1188+
@variables x(t)
1189+
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
1190+
obsfn = ModelingToolkit.build_explicit_observed_function(
1191+
sys, [x + 1, x + P, x + t], return_inplace = true)
1192+
ps = ModelingToolkit.MTKParameters(sys, [P => 2.0])
1193+
buffer = zeros(3)
1194+
@test_nowarn obsfn(buffer, [1.0], ps..., 3.0)
1195+
@test buffer [2.0, 3.0, 4.0]
1196+
end

0 commit comments

Comments
 (0)