Skip to content

Commit c926e49

Browse files
Merge pull request #369 from isaacsas/rs-cleanup
add ismassaction and get rid of some temp arrays
2 parents 93f81d2 + 4c20da5 commit c926e49

File tree

3 files changed

+41
-32
lines changed

3 files changed

+41
-32
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ export JumpProblem, DiscreteProblem
116116
export NonlinearSystem, OptimizationSystem
117117
export ode_order_lowering
118118
export PDESystem
119-
export Reaction, ReactionSystem
119+
export Reaction, ReactionSystem, ismassaction
120120
export Differential, expand_derivatives, @derivatives
121121
export IntervalDomain, ProductDomain, , CircleDomain
122122
export Equation, ConstrainedEquation

src/systems/reaction/reactionsystem.jl

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ function get_netstoich(subs, prods, sstoich, pstoich)
5252
ns
5353
end
5454

55-
5655
struct ReactionSystem <: AbstractSystem
5756
eqs::Vector{Reaction}
5857
iv::Variable
@@ -121,46 +120,56 @@ function assemble_diffusion(rs)
121120
eqs
122121
end
123122

124-
function assemble_jumps(rs)
125-
eqs = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}()
126-
127-
for rx in rs.eqs
128-
rl = jumpratelaw(rx)
129-
affect = Vector{Equation}()
130-
for (spec,stoich) in rx.netstoich
131-
push!(affect,var2op(spec) ~ var2op(spec) + stoich)
132-
end
133-
if any(isequal.(var2op(rs.iv),get_variables(rx.rate)))
134-
push!(eqs,VariableRateJump(rl,affect))
135-
elseif rx.only_use_rate || any([isequal(state,r_op) for state in rs.states, r_op in getfield.(get_variables(rx.rate),:op)])
136-
push!(eqs,ConstantRateJump(rl,affect))
137-
else
138-
reactant_stoch = isempty(rx.substoich) ? [0 => 1] : Pair.(var2op.(getfield.(rx.substrates,:op)),rx.substoich)
139-
net_stoch = map(p -> Pair(var2op(p[1]),p[2]),rx.netstoich)
140-
push!(eqs,MassActionJump(rx.rate, reactant_stoch, net_stoch))
141-
end
142-
end
143-
eqs
123+
function var2op(var)
124+
Operation(var,Vector{Expression}())
144125
end
145126

146127
# Calculate the Jump rate law (like ODE, but uses X instead of X(t).
147128
# The former generates a "MethodError: objects of type Int64 are not callable" when trying to solve the problem.
148-
function jumpratelaw(rx)
129+
function jumpratelaw(rx; rxvars=get_variables(rx.rate))
149130
@unpack rate, substrates, substoich, only_use_rate = rx
150-
rl = deepcopy(rate)
151-
for op in get_variables(rx.rate)
152-
rl = substitute(rl,op=>var2op(op.op))
131+
rl = rate
132+
for op in rxvars
133+
rl = substitute(rl, op => var2op(op.op))
153134
end
154135
if !only_use_rate
155136
for (i,stoich) in enumerate(substoich)
156-
rl *= isone(stoich) ? var2op(substrates[i].op) : Operation(binomial,[var2op(substrates[i].op),stoich])
137+
rl *= isone(stoich) ? var2op(substrates[i].op) : Operation(binomial,[var2op(substrates[i].op),stoich])
157138
end
158139
end
159140
rl
160141
end
161142

162-
function var2op(var)
163-
Operation(var,Vector{Expression}())
143+
# if haveivdep=false then time dependent rates will still be classified as mass action
144+
function ismassaction(rx, rs; rxvars = get_variables(rx.rate),
145+
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars))
146+
return !(haveivdep || rx.only_use_rate || any(convert(Variable,rxv) in states(rs) for rxv in rxvars))
147+
end
148+
149+
function assemble_jumps(rs)
150+
eqs = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}()
151+
152+
for rx in equations(rs)
153+
rxvars = get_variables(rx.rate)
154+
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars)
155+
if ismassaction(rx, rs; rxvars=rxvars, haveivdep=haveivdep)
156+
reactant_stoch = isempty(rx.substoich) ? [0 => 1] : [var2op(sub.op) => stoich for (sub,stoich) in zip(rx.substrates,rx.substoich)]
157+
net_stoch = [Pair(var2op(p[1]),p[2]) for p in rx.netstoich]
158+
push!(eqs, MassActionJump(rx.rate, reactant_stoch, net_stoch))
159+
else
160+
rl = jumpratelaw(rx, rxvars=rxvars)
161+
affect = Vector{Equation}()
162+
for (spec,stoich) in rx.netstoich
163+
push!(affect, var2op(spec) ~ var2op(spec) + stoich)
164+
end
165+
if haveivdep
166+
push!(eqs, VariableRateJump(rl,affect))
167+
else
168+
push!(eqs, ConstantRateJump(rl,affect))
169+
end
170+
end
171+
end
172+
eqs
164173
end
165174

166175
function Base.convert(::Type{<:ODESystem},rs::ReactionSystem)

test/reactionsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,21 +118,21 @@ statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(
118118
parammap = map((x,y)->Pair(x(),y),parameters(js),pars)
119119
for i = 1:14
120120
maj = MT.assemble_maj(js, js.eqs[i], statetoid,parammap)
121-
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 10*eps()
121+
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 100*eps()
122122
@test jumps[i].reactant_stoch == maj.reactant_stoch
123123
@test jumps[i].net_stoch == maj.net_stoch
124124
end
125125
for i = 15:18
126126
(i==16) && continue
127127
crj = MT.assemble_crj(js, js.eqs[i], statetoid)
128-
@test abs(crj.rate(u0,p,time) - jumps[i].rate(u0,p,time)) < 10*eps()
128+
@test abs(crj.rate(u0,p,time) - jumps[i].rate(u0,p,time)) < 100*eps()
129129
fake_integrator1 = (u=zeros(4),p=p,t=0); fake_integrator2 = deepcopy(fake_integrator1);
130130
crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2);
131131
@test fake_integrator1 == fake_integrator2
132132
end
133133
for i = 19:20
134134
crj = MT.assemble_vrj(js, js.eqs[i], statetoid)
135-
@test abs(crj.rate(u0,p,time) - jumps[i].rate(u0,p,time)) < 10*eps()
135+
@test abs(crj.rate(u0,p,time) - jumps[i].rate(u0,p,time)) < 100*eps()
136136
fake_integrator1 = (u=zeros(4),p=p,t=0.); fake_integrator2 = deepcopy(fake_integrator1);
137137
crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2);
138138
@test fake_integrator1 == fake_integrator2

0 commit comments

Comments
 (0)