Skip to content

Commit 0aa8314

Browse files
Merge pull request #353 from isaacsas/depgraphs
dependency graphs for JumpSystems
2 parents da1ce37 + b16a989 commit 0aa8314

File tree

9 files changed

+292
-31
lines changed

9 files changed

+292
-31
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1212
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1313
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
1414
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
15+
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1718
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"

src/ModelingToolkit.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ using RecursiveArrayTools
1818
import SymbolicUtils
1919
import SymbolicUtils: to_symbolic, FnType
2020

21+
import LightGraphs: SimpleDiGraph, add_edge!
22+
2123
import TreeViews
2224

2325
"""
@@ -101,6 +103,7 @@ include("systems/optimization/optimizationsystem.jl")
101103
include("systems/pde/pdesystem.jl")
102104

103105
include("systems/reaction/reactionsystem.jl")
106+
include("systems/dependency_graphs.jl")
104107

105108
include("latexify_recipes.jl")
106109
include("build_function.jl")
@@ -118,7 +121,7 @@ export Differential, expand_derivatives, @derivatives
118121
export IntervalDomain, ProductDomain, , CircleDomain
119122
export Equation, ConstrainedEquation
120123
export Operation, Expression, Variable
121-
export independent_variable, states, parameters, equations
124+
export independent_variable, states, parameters, equations
122125

123126
export calculate_jacobian, generate_jacobian, generate_function
124127
export calculate_tgrad, generate_tgrad
@@ -127,6 +130,10 @@ export calculate_factorized_W, generate_factorized_W
127130
export calculate_hessian, generate_hessian
128131
export calculate_massmatrix, generate_diffusion_function
129132

133+
export BipartiteGraph, equation_dependencies, variable_dependencies
134+
export eqeq_dependencies, varvar_dependencies
135+
export asgraph, asdigraph
136+
130137
export simplified_expr, rename, get_variables
131138
export simplify, substitute
132139
export build_function

src/systems/dependency_graphs.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# variables equations depend on as a vector of vectors of variables
2+
# each system type should define extract_variables! for a single equation
3+
function equation_dependencies(sys::AbstractSystem; variables=states(sys))
4+
eqs = equations(sys)
5+
deps = Set{Variable}()
6+
depeqs_to_vars = Vector{Vector{Variable}}(undef,length(eqs))
7+
8+
for (i,eq) in enumerate(eqs)
9+
depeqs_to_vars[i] = collect(get_variables!(deps, eq, variables))
10+
empty!(deps)
11+
end
12+
13+
depeqs_to_vars
14+
end
15+
16+
# modeled on LightGraphs SimpleGraph
17+
mutable struct BipartiteGraph{T <: Integer}
18+
ne::Int
19+
fadjlist::Vector{Vector{T}} # fadjlist[src] = [dest1,dest2,...]
20+
badjlist::Vector{Vector{T}} # badjlist[dst] = [src1,src2,...]
21+
end
22+
23+
# convert equation-variable dependencies to a bipartite graph
24+
function asgraph(eqdeps, vtois)
25+
fadjlist = Vector{Vector{Int}}(undef, length(eqdeps))
26+
for (i,dep) in enumerate(eqdeps)
27+
fadjlist[i] = sort!([vtois[var] for var in dep])
28+
end
29+
30+
badjlist = [Vector{Int}() for i = 1:length(vtois)]
31+
ne = 0
32+
for (eqidx,vidxs) in enumerate(fadjlist)
33+
foreach(vidx -> push!(badjlist[vidx], eqidx), vidxs)
34+
ne += length(vidxs)
35+
end
36+
37+
BipartiteGraph(ne, fadjlist, badjlist)
38+
end
39+
40+
function Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T<:Integer}
41+
iseq = (bg1.ne == bg2.ne)
42+
iseq &= (bg1.fadjlist == bg2.fadjlist)
43+
iseq &= (bg1.badjlist == bg2.badjlist)
44+
iseq
45+
end
46+
47+
# could be made to directly generate graph and save memory
48+
function asgraph(sys::AbstractSystem; variables=nothing, variablestoids=nothing)
49+
vs = isnothing(variables) ? states(sys) : variables
50+
eqdeps = equation_dependencies(sys, variables=vs)
51+
vtois = isnothing(variablestoids) ? Dict(convert(Variable, v) => i for (i,v) in enumerate(vs)) : variablestoids
52+
asgraph(eqdeps, vtois)
53+
end
54+
55+
# for each variable determine the equations that modify it
56+
function variable_dependencies(sys::AbstractSystem; variables=states(sys), variablestoids=nothing)
57+
eqs = equations(sys)
58+
vtois = isnothing(variablestoids) ? Dict(convert(Variable, v) => i for (i,v) in enumerate(variables)) : variablestoids
59+
60+
deps = Set{Variable}()
61+
badjlist = Vector{Vector{Int}}(undef, length(eqs))
62+
for (eidx,eq) in enumerate(eqs)
63+
modified_states!(deps, eq, variables)
64+
badjlist[eidx] = sort!([vtois[var] for var in deps])
65+
empty!(deps)
66+
end
67+
68+
fadjlist = [Vector{Int}() for i = 1:length(variables)]
69+
ne = 0
70+
for (eqidx,vidxs) in enumerate(badjlist)
71+
foreach(vidx -> push!(fadjlist[vidx], eqidx), vidxs)
72+
ne += length(vidxs)
73+
end
74+
75+
BipartiteGraph(ne, fadjlist, badjlist)
76+
end
77+
78+
# convert BipartiteGraph to LightGraph.SimpleDiGraph
79+
function asdigraph(g::BipartiteGraph, sys::AbstractSystem; variables = states(sys), equationsfirst = true)
80+
neqs = length(equations(sys))
81+
nvars = length(variables)
82+
fadjlist = deepcopy(g.fadjlist)
83+
badjlist = deepcopy(g.badjlist)
84+
85+
# offset is for determining indices for the second set of vertices
86+
offset = equationsfirst ? neqs : nvars
87+
for i = 1:offset
88+
fadjlist[i] .+= offset
89+
end
90+
91+
# add empty rows for vertices without connections
92+
append!(fadjlist, [Vector{Int}() for i=1:(equationsfirst ? nvars : neqs)])
93+
prepend!(badjlist, [Vector{Int}() for i=1:(equationsfirst ? neqs : nvars)])
94+
95+
SimpleDiGraph(g.ne, fadjlist, badjlist)
96+
end
97+
98+
# maps the i'th eq to equations that depend on it
99+
function eqeq_dependencies(eqdeps::BipartiteGraph{T}, vardeps::BipartiteGraph{T}) where {T <: Integer}
100+
g = SimpleDiGraph{T}(length(eqdeps.fadjlist))
101+
102+
for (eqidx,sidxs) in enumerate(vardeps.badjlist)
103+
# states modified by eqidx
104+
for sidx in sidxs
105+
# equations depending on sidx
106+
foreach(v -> add_edge!(g, eqidx, v), eqdeps.badjlist[sidx])
107+
end
108+
end
109+
110+
g
111+
end
112+
113+
# maps the i'th variable to variables that depend on it
114+
varvar_dependencies(eqdeps::BipartiteGraph{T}, vardeps::BipartiteGraph{T}) where {T <: Integer} = eqeq_dependencies(vardeps, eqdeps)

src/systems/jumps/jumpsystem.jl

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}
22

3-
struct JumpSystem <: AbstractSystem
4-
eqs::Vector{JumpType}
3+
struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
4+
eqs::U
55
iv::Variable
66
states::Vector{Variable}
77
ps::Vector{Variable}
@@ -11,9 +11,22 @@ end
1111

1212
function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
1313
name = gensym(:JumpSystem))
14-
JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems)
15-
end
1614

15+
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
16+
for eq in eqs
17+
if eq isa MassActionJump
18+
push!(ap.x[1], eq)
19+
elseif eq isa ConstantRateJump
20+
push!(ap.x[2], eq)
21+
elseif eq isa VariableRateJump
22+
push!(ap.x[3], eq)
23+
else
24+
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
25+
end
26+
end
27+
28+
JumpSystem{typeof(ap)}(ap, convert(Variable,iv), convert.(Variable, states), convert.(Variable, ps), name, systems)
29+
end
1730

1831

1932
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
@@ -26,6 +39,7 @@ generate_affect_function(js, affect, outputidxs) = build_function(affect, states
2639
expression=Val{false},
2740
headerfun=add_integrator_header,
2841
outputidxs=outputidxs)[2]
42+
2943
function assemble_vrj(js, vrj, statetoid)
3044
rate = generate_rate_function(js, vrj.rate)
3145
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
@@ -84,10 +98,13 @@ Generates a DiscreteProblem from an AbstractSystem
8498
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan::Tuple,
8599
parammap=DiffEqBase.NullParameters(); kwargs...)
86100
u0 = varmap_to_vars(u0map, states(sys))
87-
p = varmap_to_vars(parammap, parameters(sys))
88-
DiscreteProblem(u0, tspan, p; kwargs...)
101+
p = varmap_to_vars(parammap, parameters(sys))
102+
f = (du,u,p,t) -> du.=u # identity function to make syms works
103+
df = DiscreteFunction(f, syms=Symbol.(states(sys)))
104+
DiscreteProblem(df, u0, tspan, p; kwargs...)
89105
end
90106

107+
91108
"""
92109
```julia
93110
function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
@@ -96,25 +113,65 @@ function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
96113
Generates a JumpProblem from a JumpSystem.
97114
"""
98115
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
99-
vrjs = Vector{VariableRateJump}()
100-
crjs = Vector{ConstantRateJump}()
101-
majs = Vector{MassActionJump}()
102-
pvars = parameters(js)
116+
103117
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
104-
parammap = map((x,y)->Pair(x(),y),pvars,prob.p)
105-
106-
for j in equations(js)
107-
if j isa ConstantRateJump
108-
push!(crjs, assemble_crj(js, j, statetoid))
109-
elseif j isa VariableRateJump
110-
push!(vrjs, assemble_vrj(js, j, statetoid))
111-
elseif j isa MassActionJump
112-
push!(majs, assemble_maj(js, j, statetoid, parammap))
113-
else
114-
error("JumpSystems should only contain Constant, Variable or Mass Action Jumps.")
115-
end
116-
end
117-
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
118+
parammap = map((x,y)->Pair(x(),y), parameters(js), prob.p)
119+
eqs = equations(js)
120+
121+
majs = MassActionJump[assemble_maj(js, j, statetoid, parammap) for j in eqs.x[1]]
122+
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
123+
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
124+
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
118125
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
119-
JumpProblem(prob, aggregator, jset)
126+
127+
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
128+
jdeps = asgraph(js)
129+
vdeps = variable_dependencies(js)
130+
vtoj = jdeps.badjlist
131+
jtov = vdeps.badjlist
132+
jtoj = needs_depgraph(aggregator) ? eqeq_dependencies(jdeps, vdeps).fadjlist : nothing
133+
else
134+
vtoj = nothing; jtov = nothing; jtoj = nothing
135+
end
136+
137+
JumpProblem(prob, aggregator, jset; dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov)
120138
end
139+
140+
141+
### Functions to determine which states a jump depends on
142+
function get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables)
143+
foreach(var -> (var in variables) && push!(dep, var), vars(jump.rate))
144+
dep
145+
end
146+
147+
function get_variables!(dep, jump::MassActionJump, variables)
148+
jsr = jump.scaled_rates
149+
150+
if jsr isa Variable
151+
(jsr in variables) && push!(dep, jsr)
152+
elseif jsr isa Operation
153+
foreach(var -> (var in variables) && push!(dep, var), vars(jsr))
154+
end
155+
156+
for varasop in jump.reactant_stoch
157+
var = convert(Variable, varasop[1])
158+
(var in variables) && push!(dep, var)
159+
end
160+
161+
dep
162+
end
163+
164+
### Functions to determine which states are modified by a given jump
165+
function modified_states!(mstates, jump::Union{ConstantRateJump,VariableRateJump}, sts)
166+
for eq in jump.affect!
167+
st = convert(Variable, eq.lhs)
168+
(st in sts) && push!(mstates, st)
169+
end
170+
end
171+
172+
function modified_states!(mstates, jump::MassActionJump, sts)
173+
for (state,stoich) in jump.net_stoch
174+
st = convert(Variable, state)
175+
(st in sts) && push!(mstates, st)
176+
end
177+
end

src/systems/reaction/reactionsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ function jumpratelaw(rx)
149149
@unpack rate, substrates, substoich, only_use_rate = rx
150150
rl = deepcopy(rate)
151151
for op in get_variables(rx.rate)
152-
rl = substitute_expr!(rl,op=>var2op(op.op))
152+
rl = substitute(rl,op=>var2op(op.op))
153153
end
154154
if !only_use_rate
155155
for (i,stoich) in enumerate(substoich)

test/dep_graphs.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using ModelingToolkit, LightGraphs
2+
3+
# use a ReactionSystem to generate systems for testing
4+
@parameters k1 k2 t
5+
@variables S(t) I(t) R(t)
6+
7+
rxs = [Reaction(k1, nothing, [S]),
8+
Reaction(k1, [S], nothing),
9+
Reaction(k2, [S,I], [I], [1,1], [2]),
10+
Reaction(k2, [S,R], [S], [2,1], [2]),
11+
Reaction(k1*I, nothing, [R]),
12+
Reaction(k1*k2/(1+t), [S], [R])]
13+
rs = ReactionSystem(rxs, t, [S,I,R], [k1,k2])
14+
15+
16+
#################################
17+
# testing for Jumps
18+
#################################
19+
js = convert(JumpSystem, rs)
20+
S = convert(Variable,S); I = convert(Variable,I); R = convert(Variable,R)
21+
k1 = convert(Variable,k1); k2 = convert(Variable,k2)
22+
23+
# eq to vars they depend on
24+
eq_sdeps = [Variable[], [S], [S,I], [S,R], [I], [S]]
25+
eq_sidepsf = [Int[], [1], [1,2], [1,3], [2], [1]]
26+
eq_sidepsb = [[2,3,4,6], [3,5],[4]]
27+
deps = equation_dependencies(js)
28+
@test all(i -> isequal(Set(eq_sdeps[i]),Set(deps[i])), 1:length(rxs))
29+
depsbg = asgraph(js)
30+
@test depsbg.fadjlist == eq_sidepsf
31+
@test depsbg.badjlist == eq_sidepsb
32+
33+
# eq to params they depend on
34+
eq_pdeps = [[k1],[k1],[k2],[k2],[k1],[k1,k2]]
35+
eq_pidepsf = [[1],[1],[2],[2],[1],[1,2]]
36+
eq_pidepsb = [[1,2,5,6],[3,4,6]]
37+
deps = equation_dependencies(js, variables=parameters(js))
38+
@test all(i -> isequal(Set(eq_pdeps[i]),Set(deps[i])), 1:length(rxs))
39+
depsbg2 = asgraph(js, variables=parameters(js))
40+
@test depsbg2.fadjlist == eq_pidepsf
41+
@test depsbg2.badjlist == eq_pidepsb
42+
43+
# var to eqs that modify them
44+
s_eqdepsf = [[1,2,3,6],[3],[4,5,6]]
45+
s_eqdepsb = [[1],[1],[1,2],[3],[3],[1,3]]
46+
ne = 8
47+
bg = BipartiteGraph(ne, s_eqdepsf, s_eqdepsb)
48+
deps2 = variable_dependencies(js)
49+
@test isequal(bg,deps2)
50+
51+
# eq to eqs that depend on them
52+
eq_eqdeps = [[2,3,4,6],[2,3,4,6],[2,3,4,5,6],[4],[4],[2,3,4,6]]
53+
dg = SimpleDiGraph(6)
54+
for (eqidx,eqdeps) in enumerate(eq_eqdeps)
55+
for eqdepidx in eqdeps
56+
add_edge!(dg, eqidx, eqdepidx)
57+
end
58+
end
59+
dg3 = eqeq_dependencies(depsbg,deps2)
60+
@test dg == dg3
61+
62+
# var to vars that depend on them
63+
var_vardeps = [[1,2,3],[1,2,3],[3]]
64+
ne = 7
65+
dg = SimpleDiGraph(3)
66+
for (vidx,vdeps) in enumerate(var_vardeps)
67+
for vdepidx in vdeps
68+
add_edge!(dg, vidx, vdepidx)
69+
end
70+
end
71+
dg4 = varvar_dependencies(depsbg,deps2)
72+
@test dg == dg4
73+
74+

0 commit comments

Comments
 (0)