Skip to content

Commit c42b3e4

Browse files
committed
add mass action jump support
1 parent 7c1172e commit c42b3e4

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,29 @@ function assemble_crj(js, crj, statetoid)
4242
ConstantRateJump(rate, affect)
4343
end
4444

45+
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
46+
statetoid, ptoid, p, pcontext) where {U,V,W,V2,W2}
47+
pval = Base.eval(pcontext, Expr(maj.scaled_rates))
48+
49+
rs = Vector{Pair{Int,W}}()
50+
for (spec,stoich) in maj.reactant_stoch
51+
if iszero(spec)
52+
push!(rs, 0 => stoich)
53+
else
54+
push!(rs, statetoid[convert(Variable,spec)] => stoich)
55+
end
56+
end
57+
sort!(rs)
58+
59+
ns = Vector{Pair{Int,W2}}()
60+
for (spec,stoich) in maj.net_stoch
61+
iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
62+
push!(ns, statetoid[convert(Variable,spec)] => stoich)
63+
end
64+
sort!(ns)
65+
66+
MassActionJump(pval, rs, ns, scale_rates = false)
67+
end
4568

4669
"""
4770
```julia
@@ -68,17 +91,32 @@ Generates a JumpProblem from a JumpSystem.
6891
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
6992
vrjs = Vector{VariableRateJump}()
7093
crjs = Vector{ConstantRateJump}()
94+
majs = Vector{MassActionJump}()
95+
pvars = parameters(js)
7196
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
97+
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
98+
99+
# for mass action jumps might need to evaluate parameter expressions
100+
# populate dummy module with params as local variables
101+
# (for eval-ing parameter expressions)
102+
param_context = Module()
103+
for (i, pval) in enumerate(prob.p)
104+
psym = Symbol(pvars[i])
105+
Base.eval(param_context, :($psym = $pval))
106+
end
107+
72108
for j in equations(js)
73109
if j isa ConstantRateJump
74110
push!(crjs, assemble_crj(js, j, statetoid))
75111
elseif j isa VariableRateJump
76112
push!(vrjs, assemble_vrj(js, j, statetoid))
113+
elseif j isa MassActionJump
114+
push!(majs, assemble_maj(js, j, statetoid, ptoid, prob.p, param_context))
77115
else
78-
(j isa MassActionJump) && error("Generation of JumpProblems with MassActionJumps is not yet supported.")
116+
error("JumpSystems should only contain Constant, Variable or Mass Action Jumps.")
79117
end
80118
end
81119
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
82-
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, nothing)
120+
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
83121
JumpProblem(prob, aggregator, jset)
84122
end

test/jumpsystem.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,19 @@ jprob = JumpProblem(prob,Direct(),jset, save_positions=(false,false))
105105
m2 = getmean(jprob,Nsims)
106106

107107
# test JumpSystem solution agrees with direct version
108-
@test abs(m-m2) ./ m < .01
108+
@test abs(m-m2) ./ m < .01
109+
110+
111+
# mass action jump tests for SIR model
112+
maj1 = MassActionJump(2*β/2, [S => 1, I => 1], [S => -1, I => 1])
113+
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
114+
js = JumpSystem([maj1,maj2], t, [S,I,R], [β,γ])
115+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
116+
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
117+
dprob = DiscreteProblem(js, u₀map, tspan, parammap)
118+
jprob = JumpProblem(js, dprob, Direct())
119+
m3 = getmean(jprob,Nsims)
120+
@test abs(m2-m3) < .01
121+
122+
# mass action jump tests for other reaction types (zero order, second order, decay)
123+
# TODO

0 commit comments

Comments
 (0)