Skip to content

Commit 9b7e139

Browse files
Merge pull request #2882 from MasonProtter/detect-diagonal-noise
Detect diagonal noise in `SDESystem`
2 parents c97a0dc + dd3277f commit 9b7e139

File tree

7 files changed

+126
-11
lines changed

7 files changed

+126
-11
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1212
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1313
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1414
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
15+
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
1516
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1617
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1718
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -72,6 +73,7 @@ DataStructures = "0.17, 0.18"
7273
DeepDiffs = "1"
7374
DiffEqBase = "6.103.0"
7475
DiffEqCallbacks = "2.16, 3"
76+
DiffEqNoiseProcess = "5"
7577
DiffRules = "0.1, 1.0"
7678
Distributed = "1"
7779
Distributions = "0.23, 0.24, 0.25"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using DiffEqCallbacks
2222
using Graphs
2323
import ExprTools: splitdef, combinedef
2424
import OrderedCollections
25+
using DiffEqNoiseProcess: DiffEqNoiseProcess, WienerProcess
2526

2627
using SymbolicIndexingInterface
2728
using LinearAlgebra, SparseArrays, LabelledArrays

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,8 @@ for prop in [:eqs
655655
:solved_unknowns
656656
:split_idxs
657657
:parent
658-
:index_cache]
658+
:index_cache
659+
:is_scalar_noise]
659660
fname_get = Symbol(:get_, prop)
660661
fname_has = Symbol(:has_, prop)
661662
@eval begin

src/systems/diffeqs/sdesystem.jl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,18 @@ struct SDESystem <: AbstractODESystem
128128
The hierarchical parent system before simplification.
129129
"""
130130
parent::Any
131+
"""
132+
Signal for whether the noise equations should be treated as a scalar process. This should only
133+
be `true` when `noiseeqs isa Vector`.
134+
"""
135+
is_scalar_noise::Bool
131136

132137
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
133138
tgrad,
134139
jac,
135140
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
136141
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
137-
complete = false, index_cache = nothing, parent = nothing;
142+
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false;
138143
checks::Union{Bool, Int} = true)
139144
if checks == true || (checks & CheckComponents) > 0
140145
check_independent_variables([iv])
@@ -146,6 +151,9 @@ struct SDESystem <: AbstractODESystem
146151
throw(ArgumentError("Noise equations ill-formed. Number of rows must match number of drift equations. size(neqs,1) = $(size(neqs,1)) != length(deqs) = $(length(deqs))"))
147152
end
148153
check_equations(equations(cevents), iv)
154+
if is_scalar_noise && neqs isa AbstractMatrix
155+
throw(ArgumentError("Noise equations ill-formed. Received a matrix of noise equations of size $(size(neqs)), but `is_scalar_noise` was set to `true`. Scalar noise is only compatible with an `AbstractVector` of noise equations."))
156+
end
149157
end
150158
if checks == true || (checks & CheckUnits) > 0
151159
u = __get_unit_type(dvs, ps, iv)
@@ -154,7 +162,7 @@ struct SDESystem <: AbstractODESystem
154162
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
155163
ctrl_jac,
156164
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
157-
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent)
165+
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise)
158166
end
159167
end
160168

@@ -173,7 +181,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
173181
discrete_events = nothing,
174182
parameter_dependencies = nothing,
175183
metadata = nothing,
176-
gui_metadata = nothing)
184+
gui_metadata = nothing,
185+
complete = false,
186+
index_cache = nothing,
187+
parent = nothing,
188+
is_scalar_noise = false)
177189
name === nothing &&
178190
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
179191
iv′ = value(iv)
@@ -210,7 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
210222
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
211223
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
212224
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
213-
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks)
225+
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
226+
complete, index_cache, parent, is_scalar_noise; checks = checks)
214227
end
215228

216229
function SDESystem(sys::ODESystem, neqs; kwargs...)
@@ -225,6 +238,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
225238
isequal(nameof(sys1), nameof(sys2)) &&
226239
isequal(get_eqs(sys1), get_eqs(sys2)) &&
227240
isequal(get_noiseeqs(sys1), get_noiseeqs(sys2)) &&
241+
isequal(get_is_scalar_noise(sys1), get_is_scalar_noise(sys2)) &&
228242
_eq_unordered(get_unknowns(sys1), get_unknowns(sys2)) &&
229243
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
230244
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
@@ -616,16 +630,24 @@ function DiffEqBase.SDEProblem{iip, specialize}(
616630
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
617631

618632
noiseeqs = get_noiseeqs(sys)
633+
is_scalar_noise = get_is_scalar_noise(sys)
619634
if noiseeqs isa AbstractVector
620635
noise_rate_prototype = nothing
636+
if is_scalar_noise
637+
noise = WienerProcess(0.0, 0.0, 0.0)
638+
else
639+
noise = nothing
640+
end
621641
elseif sparsenoise
622642
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
623643
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
644+
noise = nothing
624645
else
625646
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
647+
noise = nothing
626648
end
627649

628-
SDEProblem{iip}(f, u0, tspan, p; callback = cbs,
650+
SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
629651
noise_rate_prototype = noise_rate_prototype, kwargs...)
630652
end
631653

@@ -693,22 +715,32 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
693715
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
694716

695717
noiseeqs = get_noiseeqs(sys)
718+
is_scalar_noise = get_is_scalar_noise(sys)
696719
if noiseeqs isa AbstractVector
697720
noise_rate_prototype = nothing
721+
if is_scalar_noise
722+
noise = WienerProcess(0.0, 0.0, 0.0)
723+
else
724+
noise = nothing
725+
end
698726
elseif sparsenoise
699727
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
700728
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
729+
noise = nothing
701730
else
702731
T = u0 === nothing ? Float64 : eltype(u0)
703732
noise_rate_prototype = zeros(T, size(get_noiseeqs(sys)))
733+
noise = nothing
704734
end
705735
ex = quote
706736
f = $f
707737
u0 = $u0
708738
tspan = $tspan
709739
p = $p
710740
noise_rate_prototype = $noise_rate_prototype
711-
SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
741+
noise = $noise
742+
SDEProblem(
743+
f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise,
712744
$(kwargs...))
713745
end
714746
!linenumbers ? Base.remove_linenums!(ex) : ex

src/systems/systems.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,23 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
127127
g_row > size(g, 1) && continue
128128
@views copyto!(sorted_g_rows[i, :], g[g_row, :])
129129
end
130-
131-
return SDESystem(full_equations(ode_sys), sorted_g_rows,
130+
# Fix for https://github.com/SciML/ModelingToolkit.jl/issues/2490
131+
if sorted_g_rows isa AbstractMatrix && size(sorted_g_rows, 2) == 1
132+
# If there's only one brownian variable referenced across all the equations,
133+
# we get a Nx1 matrix of noise equations, which is a special case known as scalar noise
134+
noise_eqs = sorted_g_rows[:, 1]
135+
is_scalar_noise = true
136+
elseif isdiag(sorted_g_rows)
137+
# If the noise matrix is diagonal, then the solver just takes a vector column of equations
138+
# and it interprets that as diagonal noise.
139+
noise_eqs = diag(sorted_g_rows)
140+
is_scalar_noise = false
141+
else
142+
noise_eqs = sorted_g_rows
143+
is_scalar_noise = false
144+
end
145+
return SDESystem(full_equations(ode_sys), noise_eqs,
132146
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
133-
name = nameof(ode_sys))
147+
name = nameof(ode_sys), is_scalar_noise)
134148
end
135149
end

test/dde.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ sol = solve(prob, RKMil())
7878
eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c +* x(t) + γ) * η]
7979
@mtkbuild sys = System(eqs, t)
8080
@test equations(sys) == [D(x(t)) ~ a * x(t) + b * x(t - τ) + c]
81-
@test isequal(ModelingToolkit.get_noiseeqs(sys), [α * x(t) + γ;;])
81+
@test isequal(ModelingToolkit.get_noiseeqs(sys), [α * x(t) + γ])
8282
prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,));
8383
@test_nowarn sol_mtk = solve(prob_mtk, RKMil())
8484

test/sdesystem.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,68 @@ ps = @SVector[p => 5.0, d => 0.5]
657657
sprob = SDEProblem(sys, u0, tspan, ps)
658658
@test sprob.f.g(sprob.u0, sprob.p, sprob.tspan[1]) isa SVector{2, Float64}
659659
@test_nowarn solve(sprob, ImplicitEM())
660+
661+
let
662+
@parameters σ ρ β
663+
@variables x(t) y(t) z(t)
664+
@brownian a
665+
eqs = [D(x) ~ σ * (y - x) + 0.1a * x,
666+
D(y) ~ x *- z) - y + 0.1a * y,
667+
D(z) ~ x * y - β * z + 0.1a * z]
668+
669+
@mtkbuild de = System(eqs, t)
670+
671+
u0map = [
672+
x => 1.0,
673+
y => 0.0,
674+
z => 0.0
675+
]
676+
677+
parammap = [
678+
σ => 10.0,
679+
β => 26.0,
680+
ρ => 2.33
681+
]
682+
prob = SDEProblem(de, u0map, (0.0, 100.0), parammap)
683+
# TODO: re-enable this when we support scalar noise
684+
@test solve(prob, SOSRI()).retcode == ReturnCode.Success
685+
end
686+
687+
let # test to make sure that scalar noise always receive the same kicks
688+
@variables x(t) y(t)
689+
@brownian a
690+
eqs = [D(x) ~ a,
691+
D(y) ~ a]
692+
693+
@mtkbuild de = System(eqs, t)
694+
prob = SDEProblem(de, [x => 0, y => 0], (0.0, 10.0), [])
695+
sol = solve(prob, SOSRI())
696+
@test sol[end][1] == sol[end][2]
697+
end
698+
699+
let # test that diagonal noise is correctly handled
700+
@parameters σ ρ β
701+
@variables x(t) y(t) z(t)
702+
@brownian a b c
703+
eqs = [D(x) ~ σ * (y - x) + 0.1a * x,
704+
D(y) ~ x *- z) - y + 0.1b * y,
705+
D(z) ~ x * y - β * z + 0.1c * z]
706+
707+
@mtkbuild de = System(eqs, t)
708+
709+
u0map = [
710+
x => 1.0,
711+
y => 0.0,
712+
z => 0.0
713+
]
714+
715+
parammap = [
716+
σ => 10.0,
717+
β => 26.0,
718+
ρ => 2.33
719+
]
720+
721+
prob = SDEProblem(de, u0map, (0.0, 100.0), parammap)
722+
# SOSRI only works for diagonal and scalar noise
723+
@test solve(prob, SOSRI()).retcode == ReturnCode.Success
724+
end

0 commit comments

Comments
 (0)