@@ -128,13 +128,18 @@ struct SDESystem <: AbstractODESystem
128
128
The hierarchical parent system before simplification.
129
129
"""
130
130
parent:: Any
131
+ """
132
+ Signal for whether the noise equations should be treated as a scalar process. This should only
133
+ be `true` when `noiseeqs isa Vector`.
134
+ """
135
+ is_scalar_noise:: Bool
131
136
132
137
function SDESystem (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
133
138
tgrad,
134
139
jac,
135
140
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
136
141
cevents, devents, parameter_dependencies, metadata = nothing , gui_metadata = nothing ,
137
- complete = false , index_cache = nothing , parent = nothing ;
142
+ complete = false , index_cache = nothing , parent = nothing , is_scalar_noise = false ;
138
143
checks:: Union{Bool, Int} = true )
139
144
if checks == true || (checks & CheckComponents) > 0
140
145
check_independent_variables ([iv])
@@ -146,6 +151,9 @@ struct SDESystem <: AbstractODESystem
146
151
throw (ArgumentError (" Noise equations ill-formed. Number of rows must match number of drift equations. size(neqs,1) = $(size (neqs,1 )) != length(deqs) = $(length (deqs)) " ))
147
152
end
148
153
check_equations (equations (cevents), iv)
154
+ if is_scalar_noise && neqs isa AbstractMatrix
155
+ throw (ArgumentError (" Noise equations ill-formed. Received a matrix of noise equations of size $(size (neqs)) , but `is_scalar_noise` was set to `true`. Scalar noise is only compatible with an `AbstractVector` of noise equations." ))
156
+ end
149
157
end
150
158
if checks == true || (checks & CheckUnits) > 0
151
159
u = __get_unit_type (dvs, ps, iv)
@@ -154,7 +162,7 @@ struct SDESystem <: AbstractODESystem
154
162
new (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
155
163
ctrl_jac,
156
164
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
157
- parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent)
165
+ parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise )
158
166
end
159
167
end
160
168
@@ -173,7 +181,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
173
181
discrete_events = nothing ,
174
182
parameter_dependencies = nothing ,
175
183
metadata = nothing ,
176
- gui_metadata = nothing )
184
+ gui_metadata = nothing ,
185
+ complete = false ,
186
+ index_cache = nothing ,
187
+ parent = nothing ,
188
+ is_scalar_noise = false )
177
189
name === nothing &&
178
190
throw (ArgumentError (" The `name` keyword must be provided. Please consider using the `@named` macro" ))
179
191
iv′ = value (iv)
@@ -210,7 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
210
222
SDESystem (Threads. atomic_add! (SYSTEM_COUNT, UInt (1 )),
211
223
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
212
224
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
213
- cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks)
225
+ cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
226
+ complete, index_cache, parent, is_scalar_noise; checks = checks)
214
227
end
215
228
216
229
function SDESystem (sys:: ODESystem , neqs; kwargs... )
@@ -225,6 +238,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
225
238
isequal (nameof (sys1), nameof (sys2)) &&
226
239
isequal (get_eqs (sys1), get_eqs (sys2)) &&
227
240
isequal (get_noiseeqs (sys1), get_noiseeqs (sys2)) &&
241
+ isequal (get_is_scalar_noise (sys1), get_is_scalar_noise (sys2)) &&
228
242
_eq_unordered (get_unknowns (sys1), get_unknowns (sys2)) &&
229
243
_eq_unordered (get_ps (sys1), get_ps (sys2)) &&
230
244
all (s1 == s2 for (s1, s2) in zip (get_systems (sys1), get_systems (sys2)))
@@ -616,16 +630,24 @@ function DiffEqBase.SDEProblem{iip, specialize}(
616
630
sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
617
631
618
632
noiseeqs = get_noiseeqs (sys)
633
+ is_scalar_noise = get_is_scalar_noise (sys)
619
634
if noiseeqs isa AbstractVector
620
635
noise_rate_prototype = nothing
636
+ if is_scalar_noise
637
+ noise = WienerProcess (0.0 , 0.0 , 0.0 )
638
+ else
639
+ noise = nothing
640
+ end
621
641
elseif sparsenoise
622
642
I, J, V = findnz (SparseArrays. sparse (noiseeqs))
623
643
noise_rate_prototype = SparseArrays. sparse (I, J, zero (eltype (u0)))
644
+ noise = nothing
624
645
else
625
646
noise_rate_prototype = zeros (eltype (u0), size (noiseeqs))
647
+ noise = nothing
626
648
end
627
649
628
- SDEProblem {iip} (f, u0, tspan, p; callback = cbs,
650
+ SDEProblem {iip} (f, u0, tspan, p; callback = cbs, noise,
629
651
noise_rate_prototype = noise_rate_prototype, kwargs... )
630
652
end
631
653
@@ -693,22 +715,32 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
693
715
sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
694
716
695
717
noiseeqs = get_noiseeqs (sys)
718
+ is_scalar_noise = get_is_scalar_noise (sys)
696
719
if noiseeqs isa AbstractVector
697
720
noise_rate_prototype = nothing
721
+ if is_scalar_noise
722
+ noise = WienerProcess (0.0 , 0.0 , 0.0 )
723
+ else
724
+ noise = nothing
725
+ end
698
726
elseif sparsenoise
699
727
I, J, V = findnz (SparseArrays. sparse (noiseeqs))
700
728
noise_rate_prototype = SparseArrays. sparse (I, J, zero (eltype (u0)))
729
+ noise = nothing
701
730
else
702
731
T = u0 === nothing ? Float64 : eltype (u0)
703
732
noise_rate_prototype = zeros (T, size (get_noiseeqs (sys)))
733
+ noise = nothing
704
734
end
705
735
ex = quote
706
736
f = $ f
707
737
u0 = $ u0
708
738
tspan = $ tspan
709
739
p = $ p
710
740
noise_rate_prototype = $ noise_rate_prototype
711
- SDEProblem (f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
741
+ noise = $ noise
742
+ SDEProblem (
743
+ f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise,
712
744
$ (kwargs... ))
713
745
end
714
746
! linenumbers ? Base. remove_linenums! (ex) : ex
0 commit comments