Skip to content

Commit 0cc954d

Browse files
Merge pull request #2797 from AayushSabharwal/as/inplace-observed
feat: support inplace observed
2 parents 7cdb131 + 8ea472d commit 0cc954d

File tree

3 files changed

+45
-6
lines changed

3 files changed

+45
-6
lines changed

src/systems/abstractsystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
201201
end
202202
end
203203

204+
function wrap_assignments(isscalar, assignments; let_block = false)
205+
function wrapper(expr)
206+
Func(expr.args, [], Let(assignments, expr.body, let_block))
207+
end
208+
if isscalar
209+
wrapper
210+
else
211+
wrapper, wrapper
212+
end
213+
end
214+
204215
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
205216
isscalar = !(exprs isa AbstractArray)
206217
array_vars = Dict{Any, AbstractArray{Int}}()

src/systems/diffeqs/odesystem.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ function build_explicit_observed_function(sys, ts;
382382
checkbounds = true,
383383
drop_expr = drop_expr,
384384
ps = full_parameters(sys),
385+
return_inplace = false,
385386
op = Operator,
386387
throw = true)
387388
if (isscalar = symbolic_type(ts) !== NotSymbolic())
@@ -479,16 +480,31 @@ function build_explicit_observed_function(sys, ts;
479480
if inputs === nothing
480481
args = [dvs, ps..., ivs...]
481482
else
482-
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
483+
ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds)
483484
args = [dvs, ipts, ps..., ivs...]
484485
end
485486
pre = get_postprocess_fbody(sys)
486487

487-
ex = Func(args, [],
488-
pre(Let(obsexprs,
489-
isscalar ? ts[1] : MakeArray(ts, output_type),
490-
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
491-
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
488+
# Need to keep old method of building the function since it uses `output_type`,
489+
# which can't be provided to `build_function`
490+
oop_fn = Func(args, [],
491+
pre(Let(obsexprs,
492+
isscalar ? ts[1] : MakeArray(ts, output_type),
493+
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
494+
oop_fn = expression ? oop_fn : drop_expr(@RuntimeGeneratedFunction(oop_fn))
495+
496+
if !isscalar
497+
iip_fn = build_function(ts,
498+
args...;
499+
postprocess_fbody = pre,
500+
wrap_code = wrap_array_vars(sys, ts) .∘ wrap_assignments(isscalar, obsexprs),
501+
expression = Val{expression})[2]
502+
end
503+
if isscalar || !return_inplace
504+
return oop_fn
505+
else
506+
return oop_fn, iip_fn
507+
end
492508
end
493509

494510
function _eq_unordered(a, b)

test/odesystem.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,3 +1182,15 @@ end
11821182

11831183
@test_nowarn ForwardDiff.derivative(P -> x_at_1(P), 1.0)
11841184
end
1185+
1186+
@testset "Inplace observed functions" begin
1187+
@parameters P
1188+
@variables x(t)
1189+
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
1190+
obsfn = ModelingToolkit.build_explicit_observed_function(
1191+
sys, [x + 1, x + P, x + t], return_inplace = true)[2]
1192+
ps = ModelingToolkit.MTKParameters(sys, [P => 2.0])
1193+
buffer = zeros(3)
1194+
@test_nowarn obsfn(buffer, [1.0], ps..., 3.0)
1195+
@test buffer [2.0, 3.0, 4.0]
1196+
end

0 commit comments

Comments
 (0)