Skip to content

Commit c9678ab

Browse files
Merge pull request #466 from SciML/expr_outs
Function and Expr outs
2 parents 1402694 + ddd0188 commit c9678ab

File tree

9 files changed

+458
-16
lines changed

9 files changed

+458
-16
lines changed

src/ModelingToolkit.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,13 @@ include("systems/dependency_graphs.jl")
113113
include("latexify_recipes.jl")
114114
include("build_function.jl")
115115

116-
export ODESystem, ODEFunction
117-
export SDESystem, SDEFunction
116+
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr
117+
export SDESystem, SDEFunction, SDEFunctionExpr, SDESystemExpr
118118
export JumpSystem
119-
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem, SteadyStateProblem
119+
export ODEProblem, SDEProblem
120+
export NonlinearProblem, NonlinearProblemExpr
121+
export OptimizationProblem, OptimizationProblemExpr
122+
export SteadyStateProblem, SteadyStateProblemExpr
120123
export JumpProblem, DiscreteProblem
121124
export NonlinearSystem, OptimizationSystem
122125
export ode_order_lowering

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,80 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
190190
syms = Symbol.(states(sys)))
191191
end
192192

193+
"""
194+
```julia
195+
function DiffEqBase.ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
196+
ps = parameters(sys);
197+
version = nothing, tgrad=false,
198+
jac = false, Wfact = false,
199+
sparse = false,
200+
kwargs...) where {iip}
201+
```
202+
203+
Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
204+
The arguments `dvs` and `ps` are used to set the order of the dependent
205+
variable and parameter vectors, respectively.
206+
"""
207+
struct ODEFunctionExpr{iip} end
208+
209+
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
210+
ps = parameters(sys), u0 = nothing;
211+
version = nothing, tgrad=false,
212+
jac = false, Wfact = false,
213+
sparse = false,linenumbers = false,
214+
kwargs...) where {iip}
215+
216+
idx = iip ? 2 : 1
217+
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
218+
if tgrad
219+
_tgrad = generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
220+
else
221+
_tgrad = :nothing
222+
end
223+
224+
if jac
225+
_jac = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{true}, kwargs...)[idx]
226+
else
227+
_jac = :nothing
228+
end
229+
230+
if Wfact
231+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{true}, kwargs...)
232+
_Wfact = tmp_Wfact[idx]
233+
_Wfact_t = tmp_Wfact_t[idx]
234+
else
235+
_Wfact,_Wfact_t = :nothing,:nothing
236+
end
237+
238+
M = calculate_massmatrix(sys)
239+
240+
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
241+
242+
ex = quote
243+
f = $f
244+
tgrad = $_tgrad
245+
jac = $_jac
246+
Wfact = $_Wfact
247+
Wfact_t = $_Wfact_t
248+
M = $_M
249+
250+
ODEFunction{$iip}(f,
251+
jac = jac,
252+
tgrad = tgrad,
253+
Wfact = Wfact,
254+
Wfact_t = Wfact_t,
255+
mass_matrix = M,
256+
syms = $(Symbol.(states(sys))))
257+
end
258+
!linenumbers ? striplines(ex) : ex
259+
end
260+
261+
262+
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
263+
ODEFunctionExpr{true}(sys, args...; kwargs...)
264+
end
265+
266+
193267
function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
194268
ODEProblem{true}(sys, args...; kwargs...)
195269
end
@@ -225,6 +299,51 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
225299
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
226300
end
227301

302+
"""
303+
```julia
304+
function DiffEqBase.ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
305+
parammap=DiffEqBase.NullParameters();
306+
version = nothing, tgrad=false,
307+
jac = false, Wfact = false,
308+
checkbounds = false, sparse = false,
309+
linenumbers = true, parallel=SerialForm(),
310+
kwargs...) where iip
311+
```
312+
313+
Generates a Julia expression for constructing an ODEProblem from an
314+
ODESystem and allows for automatically symbolically calculating
315+
numerical enhancements.
316+
"""
317+
struct ODEProblemExpr{iip} end
318+
319+
function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
320+
parammap=DiffEqBase.NullParameters();
321+
version = nothing, tgrad=false,
322+
jac = false, Wfact = false,
323+
checkbounds = false, sparse = false,
324+
linenumbers = false, parallel=SerialForm(),
325+
kwargs...) where iip
326+
dvs = states(sys)
327+
ps = parameters(sys)
328+
u0 = varmap_to_vars(u0map,dvs)
329+
p = varmap_to_vars(parammap,ps)
330+
f = ODEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
331+
linenumbers=linenumbers,parallel=parallel,
332+
sparse=sparse)
333+
ex = quote
334+
f = $f
335+
u0 = $u0
336+
tspan = $tspan
337+
p = $p
338+
ODEProblem(f,u0,tspan,p;$(kwargs...))
339+
end
340+
!linenumbers ? striplines(ex) : ex
341+
end
342+
343+
function ODEProblemExpr(sys::AbstractODESystem, args...; kwargs...)
344+
ODEProblemExpr{true}(sys, args...; kwargs...)
345+
end
346+
228347

229348
### Enables Steady State Problems ###
230349
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -244,7 +363,7 @@ function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,tspan,
244363
Generates an SteadyStateProblem from an ODESystem and allows for automatically
245364
symbolically calculating numerical enhancements.
246365
"""
247-
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,
366+
function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem,u0map,
248367
parammap=DiffEqBase.NullParameters();
249368
version = nothing, tgrad=false,
250369
jac = false, Wfact = false,
@@ -260,3 +379,46 @@ function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,
260379
sparse=sparse)
261380
SteadyStateProblem(f,u0,p;kwargs...)
262381
end
382+
383+
"""
384+
```julia
385+
function DiffEqBase.SteadyStateProblemExpr(sys::AbstractODESystem,u0map,tspan,
386+
parammap=DiffEqBase.NullParameters();
387+
version = nothing, tgrad=false,
388+
jac = false, Wfact = false,
389+
checkbounds = false, sparse = false,
390+
linenumbers = true, parallel=SerialForm(),
391+
kwargs...) where iip
392+
```
393+
Generates a Julia expression for building a SteadyStateProblem from
394+
an ODESystem and allows for automatically symbolically calculating
395+
numerical enhancements.
396+
"""
397+
struct SteadyStateProblemExpr{iip} end
398+
399+
function SteadyStateProblemExpr{iip}(sys::AbstractODESystem,u0map,
400+
parammap=DiffEqBase.NullParameters();
401+
version = nothing, tgrad=false,
402+
jac = false, Wfact = false,
403+
checkbounds = false, sparse = false,
404+
linenumbers = true, parallel=SerialForm(),
405+
kwargs...) where iip
406+
dvs = states(sys)
407+
ps = parameters(sys)
408+
u0 = varmap_to_vars(u0map,dvs)
409+
p = varmap_to_vars(parammap,ps)
410+
f = ODEFunctionExpr(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
411+
linenumbers=linenumbers,parallel=parallel,
412+
sparse=sparse)
413+
ex = quote
414+
f = $f
415+
u0 = $u0
416+
p = $p
417+
SteadyStateProblem(f,u0,p;$(kwargs...))
418+
end
419+
!linenumbers ? striplines(ex) : ex
420+
end
421+
422+
function SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...)
423+
SteadyStateProblemExpr{true}(sys, args...; kwargs...)
424+
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,83 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
154154
SDEFunction{true}(sys, args...; kwargs...)
155155
end
156156

157+
"""
158+
```julia
159+
function DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
160+
ps = parameters(sys);
161+
version = nothing, tgrad=false,
162+
jac = false, Wfact = false,
163+
sparse = false,
164+
kwargs...) where {iip}
165+
```
166+
167+
Create a Julia expression for an `SDEFunction` from the [`SDESystem`](@ref).
168+
The arguments `dvs` and `ps` are used to set the order of the dependent
169+
variable and parameter vectors, respectively.
170+
"""
171+
struct SDEFunctionExpr{iip} end
172+
173+
function SDEFunctionExpr{iip}(sys::SDESystem, dvs = states(sys),
174+
ps = parameters(sys), u0 = nothing;
175+
version = nothing, tgrad=false,
176+
jac = false, Wfact = false,
177+
sparse = false,linenumbers = false,
178+
kwargs...) where {iip}
179+
180+
idx = iip ? 2 : 1
181+
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
182+
g = generate_diffusion_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
183+
if tgrad
184+
_tgrad = generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
185+
else
186+
_tgrad = :nothing
187+
end
188+
189+
if jac
190+
_jac = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{true}, kwargs...)[idx]
191+
else
192+
_jac = :nothing
193+
end
194+
195+
if Wfact
196+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{true}, kwargs...)
197+
_Wfact = tmp_Wfact[idx]
198+
_Wfact_t = tmp_Wfact_t[idx]
199+
else
200+
_Wfact,_Wfact_t = :nothing,:nothing
201+
end
202+
203+
M = calculate_massmatrix(sys)
204+
205+
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
206+
207+
ex = quote
208+
f = $f
209+
g = $g
210+
tgrad = $_tgrad
211+
jac = $_jac
212+
Wfact = $_Wfact
213+
Wfact_t = $_Wfact_t
214+
M = $_M
215+
216+
SDEFunction{$iip}(f,g,
217+
jac = jac,
218+
tgrad = tgrad,
219+
Wfact = Wfact,
220+
Wfact_t = Wfact_t,
221+
mass_matrix = M,
222+
syms = $(Symbol.(states(sys))))
223+
end
224+
!linenumbers ? striplines(ex) : ex
225+
end
226+
227+
228+
function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
229+
SDEFunctionExpr{true}(sys, args...; kwargs...)
230+
end
231+
157232
function rename(sys::SDESystem,name)
158-
ODESystem(sys.eqs, sys.noiseeqs, sys.iv, sys.states, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
233+
SDESystem(sys.eqs, sys.noiseeqs, sys.iv, sys.states, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
159234
end
160235

161236
"""
@@ -203,3 +278,58 @@ end
203278
function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
204279
SDEProblem{true}(sys, args...; kwargs...)
205280
end
281+
282+
"""
283+
```julia
284+
function DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
285+
parammap=DiffEqBase.NullParameters();
286+
version = nothing, tgrad=false,
287+
jac = false, Wfact = false,
288+
checkbounds = false, sparse = false,
289+
linenumbers = true, parallel=SerialForm(),
290+
kwargs...) where iip
291+
```
292+
293+
Generates a Julia expression for constructing an ODEProblem from an
294+
ODESystem and allows for automatically symbolically calculating
295+
numerical enhancements.
296+
"""
297+
struct SDEProblemExpr{iip} end
298+
299+
function SDEProblemExpr{iip}(sys::SDESystem,u0map,tspan,
300+
parammap=DiffEqBase.NullParameters();
301+
version = nothing, tgrad=false,
302+
jac = false, Wfact = false,
303+
checkbounds = false, sparse = false,
304+
linenumbers = false, parallel=SerialForm(),
305+
kwargs...) where iip
306+
dvs = states(sys)
307+
ps = parameters(sys)
308+
u0 = varmap_to_vars(u0map,dvs)
309+
p = varmap_to_vars(parammap,ps)
310+
f = SDEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,
311+
Wfact=Wfact,checkbounds=checkbounds,
312+
linenumbers=linenumbers,parallel=parallel,
313+
sparse=sparse)
314+
if typeof(sys.noiseeqs) <: AbstractVector
315+
noise_rate_prototype = nothing
316+
elseif sparsenoise
317+
I,J,V = findnz(SparseArrays.sparse(sys.noiseeqs))
318+
noise_rate_prototype = SparseArrays.sparse(I,J,zero(eltype(u0)))
319+
else
320+
noise_rate_prototype = zeros(eltype(u0),size(sys.noiseeqs))
321+
end
322+
ex = quote
323+
f = $f
324+
u0 = $u0
325+
tspan = $tspan
326+
p = $p
327+
noise_rate_prototype = $noise_rate_prototype
328+
SDEProblem(f,f.g,u0,tspan,p;noise_rate_prototype=noise_rate_prototype,$(kwargs...))
329+
end
330+
!linenumbers ? striplines(ex) : ex
331+
end
332+
333+
function SDEProblemExpr(sys::SDESystem, args...; kwargs...)
334+
SDEProblemExpr{true}(sys, args...; kwargs...)
335+
end

0 commit comments

Comments
 (0)