Skip to content

Commit b813592

Browse files
Merge branch 'jump-construction-fix'
2 parents 21b4159 + 0929823 commit b813592

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,18 @@ function assemble_crj(js, crj, statetoid)
9191
end
9292

9393
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
94-
statetoid, parammap) where {U,V,W,V2,W2}
94+
statetoid, parammap, pcontext) where {U,V,W,V2,W2}
95+
9596
sr = maj.scaled_rates
96-
if sr isa Operation
97-
pval = simplify(substitute(sr,parammap)).value
97+
if sr isa Operation
98+
if isempty(sr.args)
99+
pval = parammap[sr.op]
100+
else
101+
pval = Base.eval(pcontext, Expr(maj.scaled_rates))
102+
end
98103
elseif sr isa Variable
99-
pval = Dict(parammap)[sr()]
100-
else
104+
pval = parammap[sr]
105+
else
101106
pval = maj.scaled_rates
102107
end
103108

@@ -164,10 +169,20 @@ sol = solve(jprob, SSAStepper())
164169
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
165170

166171
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
167-
parammap = map((x,y)->Pair(x(),y), parameters(js), prob.p)
172+
parammap = Dict(convert(Variable,param) => prob.p[i] for (i,param) in enumerate(parameters(js)))
168173
eqs = equations(js)
174+
175+
# for mass action jumps might need to evaluate parameter expressions
176+
# populate dummy module with params as local variables
177+
# (for eval-ing parameter expressions)
178+
pvars = parameters(js)
179+
param_context = Module()
180+
for (i, pval) in enumerate(prob.p)
181+
psym = Symbol(pvars[i])
182+
Base.eval(param_context, :($psym = $pval))
183+
end
169184

170-
majs = MassActionJump[assemble_maj(js, j, statetoid, parammap) for j in eqs.x[1]]
185+
majs = MassActionJump[assemble_maj(js, j, statetoid, parammap, param_context) for j in eqs.x[1]]
171186
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
172187
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
173188
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")

src/systems/reaction/reactionsystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,10 @@ function assemble_jumps(rs)
244244
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars)
245245
if ismassaction(rx, rs; rxvars=rxvars, haveivdep=haveivdep)
246246
reactant_stoch = isempty(rx.substoich) ? [0 => 1] : [var2op(sub.op) => stoich for (sub,stoich) in zip(rx.substrates,rx.substoich)]
247+
coef = isempty(rx.substoich) ? one(eltype(rx.substoich)) : prod(stoich -> factorial(stoich), rx.substoich)
248+
rate = isone(coef) ? rx.rate : rx.rate/coef
247249
net_stoch = [Pair(var2op(p[1]),p[2]) for p in rx.netstoich]
248-
push!(eqs, MassActionJump(rx.rate, reactant_stoch, net_stoch))
250+
push!(eqs, MassActionJump(rate, reactant_stoch, net_stoch, scale_rates=false))
249251
else
250252
rl = jumpratelaw(rx, rxvars=rxvars)
251253
affect = Vector{Equation}()

test/reactionsystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,15 @@ jumps[19] = VariableRateJump((u,p,t) -> p[19]*u[1]*t, integrator -> (integrator.
115115
jumps[20] = VariableRateJump((u,p,t) -> p[20]*t*u[1]*binomial(u[2],2)*u[3], integrator -> (integrator.u[2] -= 2; integrator.u[3] -= 1; integrator.u[4] += 2))
116116

117117
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
118-
parammap = map((x,y)->Pair(x(),y),parameters(js),pars)
118+
parammap = Dict(convert(Variable,param) => pars[i] for (i,param) in enumerate(parameters(js)))
119+
pvars = parameters(js)
120+
param_context = Module()
121+
for (i, pval) in enumerate(pars)
122+
psym = Symbol(pvars[i])
123+
Base.eval(param_context, :($psym = $pval))
124+
end
119125
for i = 1:14
120-
maj = MT.assemble_maj(js, js.eqs[i], statetoid,parammap)
126+
maj = MT.assemble_maj(js, js.eqs[i], statetoid, parammap, param_context)
121127
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 100*eps()
122128
@test jumps[i].reactant_stoch == maj.reactant_stoch
123129
@test jumps[i].net_stoch == maj.net_stoch

0 commit comments

Comments
 (0)