Skip to content

Commit 80def4c

Browse files
Merge pull request #2825 from SciML/eval
Make eval great again
2 parents 5f2a594 + adf98ba commit 80def4c

File tree

10 files changed

+204
-69
lines changed

10 files changed

+204
-69
lines changed

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pages = [
3131
"basics/MTKModel_Connector.md",
3232
"basics/Validation.md",
3333
"basics/DependencyGraphs.md",
34+
"basics/Precompilation.md",
3435
"basics/FAQ.md"],
3536
"System Types" => Any["systems/ODESystem.md",
3637
"systems/SDESystem.md",

docs/src/basics/Precompilation.md

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Working with Precompilation and Binary Building
2+
3+
## tl;dr, I just want precompilation to work
4+
5+
The tl;dr is, if you want to make precompilation work then instead of
6+
7+
```julia
8+
ODEProblem(sys, u0, tspan, p)
9+
```
10+
11+
use:
12+
13+
```julia
14+
ODEProblem(sys, u0, tspan, p, eval_module = @__MODULE__, eval_expression = true)
15+
```
16+
17+
As a full example, here's an example of a module that would precompile effectively:
18+
19+
```julia
20+
module PrecompilationMWE
21+
using ModelingToolkit
22+
23+
@variables x(ModelingToolkit.t_nounits)
24+
@named sys = ODESystem([ModelingToolkit.D_nounits(x) ~ -x + 1], ModelingToolkit.t_nounits)
25+
prob = ODEProblem(structural_simplify(sys), [x => 30.0], (0, 100), [],
26+
eval_expression = true, eval_module = @__MODULE__)
27+
28+
end
29+
```
30+
31+
If you use that in your package's code then 99% of the time that's the right answer to get
32+
precompilation working.
33+
34+
## I'm doing something fancier and need a bit more of an explanation
35+
36+
Oh you dapper soul, time for the bigger explanation. Julia's `eval` function evaluates a
37+
function into a module at a specified world-age. If you evaluate a function within a function
38+
and try to call it from within that same function, you will hit a world-age error. This looks like:
39+
40+
```julia
41+
function worldageerror()
42+
f = eval(:((x) -> 2x))
43+
f(2)
44+
end
45+
```
46+
47+
```
48+
julia> worldageerror()
49+
ERROR: MethodError: no method matching (::var"#5#6")(::Int64)
50+
51+
Closest candidates are:
52+
(::var"#5#6")(::Any) (method too new to be called from this world context.)
53+
@ Main REPL[12]:2
54+
```
55+
56+
This is done for many reasons, in particular if the code that is called within a function could change
57+
at any time, then Julia functions could not ever properly optimize because the meaning of any function
58+
or dispatch could always change and you would lose performance by guarding against that. For a full
59+
discussion of world-age, see [this paper](https://arxiv.org/abs/2010.07516).
60+
61+
However, this would be greatly inhibiting to standard ModelingToolkit usage because then something as
62+
simple as building an ODEProblem in a function and then using it would get a world age error:
63+
64+
```julia
65+
function wouldworldage()
66+
prob = ODEProblem(sys, [], (0.0, 1.0))
67+
sol = solve(prob)
68+
end
69+
```
70+
71+
The reason is because `prob.f` would be constructed via `eval`, and thus `prob.f` could not be called
72+
in the function, which means that no solve could ever work in the same function that generated the
73+
problem. That does mean that:
74+
75+
```julia
76+
function wouldworldage()
77+
prob = ODEProblem(sys, [], (0.0, 1.0))
78+
end
79+
sol = solve(prob)
80+
```
81+
82+
is fine, or putting
83+
84+
```julia
85+
prob = ODEProblem(sys, [], (0.0, 1.0))
86+
sol = solve(prob)
87+
```
88+
89+
at the top level of a module is perfectly fine too. They just cannot happen in the same function.
90+
91+
This would be a major limitation to ModelingToolkit, and thus we developed
92+
[RuntimeGeneratedFunctions](https://github.com/SciML/RuntimeGeneratedFunctions.jl) to get around
93+
this limitation. It will not be described beyond that, it is dark art and should not be investigated.
94+
But it does the job. But that does mean that it plays... oddly with Julia's compilation.
95+
96+
There are ways to force RuntimeGeneratedFunctions to perform their evaluation and caching within
97+
a given module, but that is not recommended because it does not play nicely with Julia v1.9's
98+
introduction of package images for binary caching.
99+
100+
Thus when trying to make things work with precompilation, we recommend using `eval`. This is
101+
done by simply adding `eval_expression=true` to the problem constructor. However, this is not
102+
a silver bullet because the moment you start using eval, all potential world-age restrictions
103+
apply, and thus it is recommended this is simply used for evaluating at the top level of modules
104+
for the purpose of precompilation and ensuring binaries of your MTK functions are built correctly.
105+
106+
However, there is one caveat that `eval` in Julia works depending on the module that it is given.
107+
If you have `MyPackage` that you are precompiling into, or say you are using `juliac` or PackageCompiler
108+
or some other static ahead-of-time (AOT) Julia compiler, then you don't want to accidentally `eval`
109+
that function to live in ModelingToolkit and instead want to make sure it is `eval`'d to live in `MyPackage`
110+
(since otherwise it will not cache into the binary). ModelingToolkit cannot know that in advance, and thus
111+
you have to pass in the module you wish for the functions to "live" in. This is done via the `eval_module`
112+
argument.
113+
114+
Hence `ODEProblem(sys, u0, tspan, p, eval_module=@__MODULE__, eval_expression=true)` will work if you
115+
are running this expression in the scope of the module you wish to be precompiling. However, if you are
116+
attempting to AOT compile a different module, this means that `eval_module` needs to be appropriately
117+
chosen. And, because `eval_expression=true`, all caveats of world-age apply.

src/systems/clock_inference.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ end
192192
function generate_discrete_affect(
193193
osys::AbstractODESystem, syss, inputs, continuous_id, id_to_clock;
194194
checkbounds = true,
195-
eval_module = @__MODULE__, eval_expression = true)
195+
eval_module = @__MODULE__, eval_expression = false)
196196
@static if VERSION < v"1.7"
197197
error("The `generate_discrete_affect` function requires at least Julia 1.7")
198198
end
@@ -412,15 +412,17 @@ function generate_discrete_affect(
412412
push!(svs, sv)
413413
end
414414
if eval_expression
415+
affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs)
416+
inits = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), init_funs)
417+
else
415418
affects = map(affect_funs) do a
416-
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
419+
drop_expr(RuntimeGeneratedFunction(
420+
eval_module, eval_module, toexpr(LiteralExpr(a))))
417421
end
418422
inits = map(init_funs) do a
419-
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
423+
drop_expr(RuntimeGeneratedFunction(
424+
eval_module, eval_module, toexpr(LiteralExpr(a))))
420425
end
421-
else
422-
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
423-
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
424426
end
425427
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
426428
return affects, inits, clocks, svs, appended_parameters, defaults

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
313313
version = nothing, tgrad = false,
314314
jac = false, p = nothing,
315315
t = nothing,
316-
eval_expression = true,
316+
eval_expression = false,
317317
sparse = false, simplify = false,
318318
eval_module = @__MODULE__,
319319
steady_state = false,
@@ -327,12 +327,12 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
327327
if !iscomplete(sys)
328328
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
329329
end
330-
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
330+
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
331331
expression_module = eval_module, checkbounds = checkbounds,
332332
kwargs...)
333-
f_oop, f_iip = eval_expression ?
334-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
335-
f_gen
333+
f_oop, f_iip = eval_expression ? eval_module.eval.(f_gen) :
334+
(drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)
335+
336336
f(u, p, t) = f_oop(u, p, t)
337337
f(du, u, p, t) = f_iip(du, u, p, t)
338338
f(u, p::Tuple{Vararg{Number}}, t) = f_oop(u, p, t)
@@ -352,12 +352,12 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
352352
if tgrad
353353
tgrad_gen = generate_tgrad(sys, dvs, ps;
354354
simplify = simplify,
355-
expression = Val{eval_expression},
355+
expression = Val{true},
356356
expression_module = eval_module,
357357
checkbounds = checkbounds, kwargs...)
358-
tgrad_oop, tgrad_iip = eval_expression ?
359-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in tgrad_gen) :
360-
tgrad_gen
358+
tgrad_oop, tgrad_iip = eval_expression ? eval_module.eval.(tgrad_gen) :
359+
(drop_expr(RuntimeGeneratedFunction(
360+
eval_module, eval_module, ex)) for ex in tgrad_gen)
361361
if p isa Tuple
362362
__tgrad(u, p, t) = tgrad_oop(u, p..., t)
363363
__tgrad(J, u, p, t) = tgrad_iip(J, u, p..., t)
@@ -374,12 +374,13 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
374374
if jac
375375
jac_gen = generate_jacobian(sys, dvs, ps;
376376
simplify = simplify, sparse = sparse,
377-
expression = Val{eval_expression},
377+
expression = Val{true},
378378
expression_module = eval_module,
379379
checkbounds = checkbounds, kwargs...)
380-
jac_oop, jac_iip = eval_expression ?
381-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
382-
jac_gen
380+
jac_oop, jac_iip = eval_expression ? eval_module.eval.(jac_gen) :
381+
(drop_expr(RuntimeGeneratedFunction(
382+
eval_module, eval_module, ex)) for ex in jac_gen)
383+
383384
_jac(u, p, t) = jac_oop(u, p, t)
384385
_jac(J, u, p, t) = jac_iip(J, u, p, t)
385386
_jac(u, p::Tuple{Vararg{Number}}, t) = jac_oop(u, p, t)
@@ -474,7 +475,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
474475
ddvs = map(diff2term Differential(get_iv(sys)), dvs),
475476
version = nothing, p = nothing,
476477
jac = false,
477-
eval_expression = true,
478+
eval_expression = false,
478479
sparse = false, simplify = false,
479480
eval_module = @__MODULE__,
480481
checkbounds = false,
@@ -485,12 +486,11 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
485486
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
486487
end
487488
f_gen = generate_function(sys, dvs, ps; implicit_dae = true,
488-
expression = Val{eval_expression},
489+
expression = Val{true},
489490
expression_module = eval_module, checkbounds = checkbounds,
490491
kwargs...)
491-
f_oop, f_iip = eval_expression ?
492-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
493-
f_gen
492+
f_oop, f_iip = eval_expression ? eval_module.eval.(f_gen) :
493+
(drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)
494494
f(du, u, p, t) = f_oop(du, u, p, t)
495495
f(du, u, p::MTKParameters, t) = f_oop(du, u, p..., t)
496496
f(out, du, u, p, t) = f_iip(out, du, u, p, t)
@@ -499,12 +499,13 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
499499
if jac
500500
jac_gen = generate_dae_jacobian(sys, dvs, ps;
501501
simplify = simplify, sparse = sparse,
502-
expression = Val{eval_expression},
502+
expression = Val{true},
503503
expression_module = eval_module,
504504
checkbounds = checkbounds, kwargs...)
505-
jac_oop, jac_iip = eval_expression ?
506-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
507-
jac_gen
505+
jac_oop, jac_iip = eval_expression ? eval_module.eval.(jac_gen) :
506+
(drop_expr(RuntimeGeneratedFunction(
507+
eval_module, eval_module, ex)) for ex in jac_gen)
508+
508509
_jac(du, u, p, ˍ₋gamma, t) = jac_oop(du, u, p, ˍ₋gamma, t)
509510
_jac(du, u, p::MTKParameters, ˍ₋gamma, t) = jac_oop(du, u, p..., ˍ₋gamma, t)
510511

@@ -555,7 +556,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
555556
expression = Val{true},
556557
expression_module = eval_module, checkbounds = checkbounds,
557558
kwargs...)
558-
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
559+
f_oop, f_iip = (drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)
559560
f(u, h, p, t) = f_oop(u, h, p, t)
560561
f(u, h, p::MTKParameters, t) = f_oop(u, h, p..., t)
561562
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
@@ -580,7 +581,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
580581
expression = Val{true},
581582
expression_module = eval_module, checkbounds = checkbounds,
582583
kwargs...)
583-
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
584+
f_oop, f_iip = (drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)
584585
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
585586
isdde = true, kwargs...)
586587
g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
@@ -770,7 +771,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
770771
checkbounds = false, sparse = false,
771772
simplify = false,
772773
linenumbers = true, parallel = SerialForm(),
773-
eval_expression = true,
774+
eval_expression = false,
774775
use_union = true,
775776
tofloat = true,
776777
symbolic_u0 = false,

src/systems/diffeqs/sdesystem.jl

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -407,21 +407,21 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
407407
ps = parameters(sys),
408408
u0 = nothing;
409409
version = nothing, tgrad = false, sparse = false,
410-
jac = false, Wfact = false, eval_expression = true,
410+
jac = false, Wfact = false, eval_expression = false,
411411
checkbounds = false,
412412
kwargs...) where {iip}
413413
if !iscomplete(sys)
414414
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
415415
end
416416
dvs = scalarize.(dvs)
417417

418-
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...)
419-
f_oop, f_iip = eval_expression ?
420-
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in f_gen) : f_gen
421-
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{eval_expression},
418+
f_gen = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
419+
f_oop, f_iip = eval_expression ? eval_module.eval.(f_gen) :
420+
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in f_gen)
421+
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
422422
kwargs...)
423-
g_oop, g_iip = eval_expression ?
424-
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen) : g_gen
423+
g_oop, g_iip = eval_expression ? eval_module.eval.(g_gen) :
424+
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
425425

426426
f(u, p, t) = f_oop(u, p, t)
427427
f(u, p::MTKParameters, t) = f_oop(u, p..., t)
@@ -433,11 +433,11 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
433433
g(du, u, p::MTKParameters, t) = g_iip(du, u, p..., t)
434434

435435
if tgrad
436-
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{eval_expression},
436+
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true},
437437
kwargs...)
438-
tgrad_oop, tgrad_iip = eval_expression ?
439-
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tgrad_gen) :
440-
tgrad_gen
438+
tgrad_oop, tgrad_iip = eval_expression ? eval_module.eval.(tgrad_gen) :
439+
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tgrad_gen)
440+
441441
_tgrad(u, p, t) = tgrad_oop(u, p, t)
442442
_tgrad(u, p::MTKParameters, t) = tgrad_oop(u, p..., t)
443443
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
@@ -447,11 +447,11 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
447447
end
448448

449449
if jac
450-
jac_gen = generate_jacobian(sys, dvs, ps; expression = Val{eval_expression},
450+
jac_gen = generate_jacobian(sys, dvs, ps; expression = Val{true},
451451
sparse = sparse, kwargs...)
452-
jac_oop, jac_iip = eval_expression ?
453-
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in jac_gen) :
454-
jac_gen
452+
jac_oop, jac_iip = eval_expression ? eval_module.eval.(jac_gen) :
453+
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in jac_gen)
454+
455455
_jac(u, p, t) = jac_oop(u, p, t)
456456
_jac(u, p::MTKParameters, t) = jac_oop(u, p..., t)
457457
_jac(J, u, p, t) = jac_iip(J, u, p, t)
@@ -463,12 +463,11 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
463463
if Wfact
464464
tmp_Wfact, tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true;
465465
expression = Val{true}, kwargs...)
466-
Wfact_oop, Wfact_iip = eval_expression ?
467-
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact) :
468-
tmp_Wfact
469-
Wfact_oop_t, Wfact_iip_t = eval_expression ?
470-
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact_t) :
471-
tmp_Wfact_t
466+
Wfact_oop, Wfact_iip = eval_expression ? eval_module.eval.(tmp_Wfact) :
467+
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact)
468+
Wfact_oop_t, Wfact_iip_t = eval_expression ? eval_module.eval.(tmp_Wfact_t) :
469+
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact_t)
470+
472471
_Wfact(u, p, dtgamma, t) = Wfact_oop(u, p, dtgamma, t)
473472
_Wfact(u, p::MTKParameters, dtgamma, t) = Wfact_oop(u, p..., dtgamma, t)
474473
_Wfact(W, u, p, dtgamma, t) = Wfact_iip(W, u, p, dtgamma, t)

0 commit comments

Comments
 (0)