Skip to content

Commit 317f9c7

Browse files
feat: generate vector noise function for diagonal noise matrix
1 parent 97e51fe commit 317f9c7

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
231231
if isdde
232232
eqs = delay_to_function(sys, eqs)
233233
end
234+
if isdiag(eqs)
235+
eqs = diag(eqs)
236+
end
234237
u = map(x -> time_varying_as_func(value(x), sys), dvs)
235238
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
236239
reorder_parameters(get_index_cache(sys), ps)

test/sdesystem.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,3 +629,16 @@ sprob = SDEProblem(sys, u0, tspan, ps)
629629
@test !isinplace(sprob)
630630
@test !isinplace(sprob.f)
631631
@test_nowarn solve(sprob, ImplicitEM())
632+
633+
# Ensure diagonal noise generates vector noise function
634+
@variables y(tt)
635+
@brownian b
636+
eqs = [D(x) ~ p - d * x + a * sqrt(p)
637+
D(y) ~ p - d * y + b * sqrt(d)]
638+
@mtkbuild sys = System(eqs, tt)
639+
u0 = @SVector[x => 10.0, y => 20.0]
640+
tspan = (0.0, 10.0)
641+
ps = @SVector[p => 5.0, d => 0.5]
642+
sprob = SDEProblem(sys, u0, tspan, ps)
643+
@test sprob.f.g(sprob.u0, sprob.p, sprob.tspan[1]) isa SVector{2, Float64}
644+
@test_nowarn solve(sprob, ImplicitEM())

0 commit comments

Comments
 (0)