Skip to content

Commit 8bf2adc

Browse files
Merge pull request #377 from isaacsas/update-get-variables
update get_variables! for Equations
2 parents 3b2ea9a + c361295 commit 8bf2adc

File tree

5 files changed

+60
-31
lines changed

5 files changed

+60
-31
lines changed

src/systems/dependency_graphs.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
# each system type should define extract_variables! for a single equation
33
function equation_dependencies(sys::AbstractSystem; variables=states(sys))
44
eqs = equations(sys)
5-
deps = Set{Variable}()
5+
deps = Set{Operation}()
66
depeqs_to_vars = Vector{Vector{Variable}}(undef,length(eqs))
77

88
for (i,eq) in enumerate(eqs)
9-
depeqs_to_vars[i] = collect(get_variables!(deps, eq, variables))
9+
get_variables!(deps, eq, variables)
10+
depeqs_to_vars[i] = [convert(Variable,v) for v in deps]
1011
empty!(deps)
1112
end
1213

@@ -57,11 +58,11 @@ function variable_dependencies(sys::AbstractSystem; variables=states(sys), varia
5758
eqs = equations(sys)
5859
vtois = isnothing(variablestoids) ? Dict(convert(Variable, v) => i for (i,v) in enumerate(variables)) : variablestoids
5960

60-
deps = Set{Variable}()
61+
deps = Set{Operation}()
6162
badjlist = Vector{Vector{Int}}(undef, length(eqs))
6263
for (eidx,eq) in enumerate(eqs)
6364
modified_states!(deps, eq, variables)
64-
badjlist[eidx] = sort!([vtois[var] for var in deps])
65+
badjlist[eidx] = sort!([vtois[convert(Variable,var)] for var in deps])
6566
empty!(deps)
6667
end
6768

src/systems/jumps/jumpsystem.jl

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,39 +139,26 @@ end
139139

140140

141141
### Functions to determine which states a jump depends on
142-
function get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables)
143-
foreach(var -> (var in variables) && push!(dep, var), vars(jump.rate))
144-
dep
145-
end
142+
get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables) = get_variables!(dep, jump.rate, variables)
146143

147144
function get_variables!(dep, jump::MassActionJump, variables)
148-
jsr = jump.scaled_rates
149-
150-
if jsr isa Variable
151-
(jsr in variables) && push!(dep, jsr)
152-
elseif jsr isa Operation
153-
foreach(var -> (var in variables) && push!(dep, var), vars(jsr))
154-
end
155-
145+
get_variables!(dep, jump.scaled_rates, variables)
156146
for varasop in jump.reactant_stoch
157-
var = convert(Variable, varasop[1])
158-
(var in variables) && push!(dep, var)
147+
(varasop[1].op in variables) && push!(dep, varasop[1])
159148
end
160-
161149
dep
162150
end
163151

164152
### Functions to determine which states are modified by a given jump
165153
function modified_states!(mstates, jump::Union{ConstantRateJump,VariableRateJump}, sts)
166154
for eq in jump.affect!
167-
st = convert(Variable, eq.lhs)
168-
(st in sts) && push!(mstates, st)
155+
st = eq.lhs
156+
(st.op in sts) && push!(mstates, st)
169157
end
170158
end
171159

172160
function modified_states!(mstates, jump::MassActionJump, sts)
173161
for (state,stoich) in jump.net_stoch
174-
st = convert(Variable, state)
175-
(st in sts) && push!(mstates, st)
162+
(state.op in sts) && push!(mstates, state)
176163
end
177164
end

src/utils.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,25 @@ get_variables(O::Operation)
8888
8989
Returns the variables in the Operation
9090
"""
91-
get_variables(e::Constant, vars = Operation[]) = vars
92-
function get_variables(e::Operation, vars = Operation[])
91+
get_variables!(vars, e::Constant, varlist=nothing) = vars
92+
get_variables(e::Constant, varlist=nothing) = get_variables!(Operation[], e, varlist)
93+
94+
function get_variables!(vars, e::Operation, varlist=nothing)
9395
if is_singleton(e)
94-
push!(vars, e)
96+
(isnothing(varlist) ? true : (e.op in varlist)) && push!(vars, e)
9597
else
96-
foreach(x -> get_variables(x, vars), e.args)
98+
foreach(x -> get_variables!(vars, x, varlist), e.args)
9799
end
98100
return unique(vars)
99101
end
102+
get_variables(e::Operation, varlist=nothing) = get_variables!(Operation[], e, varlist)
103+
104+
function get_variables!(vars, e::Equation, varlist=nothing)
105+
get_variables!(vars, e.rhs, varlist)
106+
end
107+
get_variables(e::Equation, varlist=nothing) = get_variables!(Operation[],e,varlist)
108+
109+
modified_states!(mstates, e::Equation, statelist=nothing) = get_variables!(mstates, e.lhs, statelist)
100110

101111
# variable substitution
102112
"""
@@ -119,4 +129,4 @@ function _substitute(expr, dict::Dict)
119129
simplify(SymbolicUtils.substitute(expr, dict))
120130
end
121131

122-
@deprecate substitute_expr!(expr,s) substitute(expr,s)
132+
@deprecate substitute_expr!(expr,s) substitute(expr,s)

test/dep_graphs.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ rs = ReactionSystem(rxs, t, [S,I,R], [k1,k2])
1414

1515

1616
#################################
17-
# testing for Jumps
17+
# testing for Jumps / all dgs
1818
#################################
1919
js = convert(JumpSystem, rs)
2020
S = convert(Variable,S); I = convert(Variable,I); R = convert(Variable,R)
@@ -71,4 +71,32 @@ end
7171
dg4 = varvar_dependencies(depsbg,deps2)
7272
@test dg == dg4
7373

74-
74+
#####################################
75+
# testing for ODE/SDEs
76+
#####################################
77+
os = convert(ODESystem, rs)
78+
deps = equation_dependencies(os)
79+
eq_sdeps = [[S,I], [S,I], [S,I,R]]
80+
@test all(i -> isequal(Set(eq_sdeps[i]),Set(deps[i])), 1:length(deps))
81+
82+
sdes = convert(SDESystem, rs)
83+
deps = equation_dependencies(sdes)
84+
@test all(i -> isequal(Set(eq_sdeps[i]),Set(deps[i])), 1:length(deps))
85+
86+
deps = variable_dependencies(os)
87+
s_eqdeps = [[1],[2],[3]]
88+
@test deps.fadjlist == s_eqdeps
89+
90+
#####################################
91+
# testing for nonlin sys
92+
#####################################
93+
@variables x y z
94+
@parameters σ ρ β
95+
96+
eqs = [0 ~ σ*(y-x),
97+
0 ~ ρ-y,
98+
0 ~ y - β*z]
99+
ns = NonlinearSystem(eqs, [x,y,z],[σ,ρ,β])
100+
deps = equation_dependencies(ns)
101+
eq_sdeps = [[x,y],[y],[y,z]]
102+
@test all(i -> isequal(Set(deps[i]),Set(convert.(Variable,eq_sdeps[i]))), 1:length(deps))

test/jumpsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,14 @@ jprob = JumpProblem(js3, dprob, Direct())
118118
m3 = getmean(jprob,Nsims)
119119
@test abs(m-m3)/m < .01
120120

121-
# maj jump test with dep graphs
121+
# maj jump test with various dep graphs
122122
js3b = JumpSystem([maj1,maj2], t, [S,I,R], [β,γ])
123123
jprobb = JumpProblem(js3b, dprob, NRM())
124124
m4 = getmean(jprobb,Nsims)
125125
@test abs(m-m4)/m < .01
126+
jprobc = JumpProblem(js3b, dprob, RSSA())
127+
m4 = getmean(jprobc,Nsims)
128+
@test abs(m-m4)/m < .01
126129

127130
# mass action jump tests for other reaction types (zero order, decay)
128131
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])

0 commit comments

Comments
 (0)