Skip to content

Commit d64f973

Browse files
Merge pull request #2834 from AayushSabharwal/as/sde-sarray
fix: infer oop form for SDEProblem/SDEFunction with StaticArrays
2 parents f3b040d + 54df3cc commit d64f973

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
231231
if isdde
232232
eqs = delay_to_function(sys, eqs)
233233
end
234+
if eqs isa AbstractMatrix && isdiag(eqs)
235+
eqs = diag(eqs)
236+
end
234237
u = map(x -> time_varying_as_func(value(x), sys), dvs)
235238
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
236239
reorder_parameters(get_index_cache(sys), ps)
@@ -403,14 +406,14 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
403406
checks = false)
404407
end
405408

406-
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
409+
function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(sys),
407410
ps = parameters(sys),
408411
u0 = nothing;
409412
version = nothing, tgrad = false, sparse = false,
410413
jac = false, Wfact = false, eval_expression = false,
411414
eval_module = @__MODULE__,
412415
checkbounds = false,
413-
kwargs...) where {iip}
416+
kwargs...) where {iip, specialize}
414417
if !iscomplete(sys)
415418
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
416419
end
@@ -480,7 +483,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
480483

481484
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
482485

483-
SDEFunction{iip}(f, g,
486+
SDEFunction{iip, specialize}(f, g,
484487
sys = sys,
485488
jac = _jac === nothing ? nothing : _jac,
486489
tgrad = _tgrad === nothing ? nothing : _tgrad,
@@ -505,6 +508,16 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
505508
SDEFunction{true}(sys, args...; kwargs...)
506509
end
507510

511+
function DiffEqBase.SDEFunction{true}(sys::SDESystem, args...;
512+
kwargs...)
513+
SDEFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
514+
end
515+
516+
function DiffEqBase.SDEFunction{false}(sys::SDESystem, args...;
517+
kwargs...)
518+
SDEFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
519+
end
520+
508521
"""
509522
```julia
510523
DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
@@ -583,14 +596,16 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
583596
SDEFunctionExpr{true}(sys, args...; kwargs...)
584597
end
585598

586-
function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspan(sys),
599+
function DiffEqBase.SDEProblem{iip, specialize}(
600+
sys::SDESystem, u0map = [], tspan = get_tspan(sys),
587601
parammap = DiffEqBase.NullParameters();
588602
sparsenoise = nothing, check_length = true,
589-
callback = nothing, kwargs...) where {iip}
603+
callback = nothing, kwargs...) where {iip, specialize}
590604
if !iscomplete(sys)
591605
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`")
592606
end
593-
f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; check_length,
607+
f, u0, p = process_DEProblem(
608+
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
594609
kwargs...)
595610
cbs = process_events(sys; callback, kwargs...)
596611
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
@@ -628,6 +643,21 @@ function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
628643
SDEProblem{true}(sys, args...; kwargs...)
629644
end
630645

646+
function DiffEqBase.SDEProblem(sys::SDESystem,
647+
u0map::StaticArray,
648+
args...;
649+
kwargs...)
650+
SDEProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
651+
end
652+
653+
function DiffEqBase.SDEProblem{true}(sys::SDESystem, args...; kwargs...)
654+
SDEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
655+
end
656+
657+
function DiffEqBase.SDEProblem{false}(sys::SDESystem, args...; kwargs...)
658+
SDEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
659+
end
660+
631661
"""
632662
```julia
633663
DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,

test/sdesystem.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,3 +614,31 @@ sys2 = complete(sys2)
614614
prob = SDEProblem(sys1, sts .=> [1.0, 0.0, 0.0],
615615
(0.0, 100.0), ps .=> (10.0, 26.0))
616616
solve(prob, LambaEulerHeun(), seed = 1)
617+
618+
# SDEProblem construction with StaticArrays
619+
# Issue#2814
620+
@parameters p d
621+
@variables x(tt)
622+
@brownian a
623+
eqs = [D(x) ~ p - d * x + a * sqrt(p)]
624+
@mtkbuild sys = System(eqs, tt)
625+
u0 = @SVector[x => 10.0]
626+
tspan = (0.0, 10.0)
627+
ps = @SVector[p => 5.0, d => 0.5]
628+
sprob = SDEProblem(sys, u0, tspan, ps)
629+
@test !isinplace(sprob)
630+
@test !isinplace(sprob.f)
631+
@test_nowarn solve(sprob, ImplicitEM())
632+
633+
# Ensure diagonal noise generates vector noise function
634+
@variables y(tt)
635+
@brownian b
636+
eqs = [D(x) ~ p - d * x + a * sqrt(p)
637+
D(y) ~ p - d * y + b * sqrt(d)]
638+
@mtkbuild sys = System(eqs, tt)
639+
u0 = @SVector[x => 10.0, y => 20.0]
640+
tspan = (0.0, 10.0)
641+
ps = @SVector[p => 5.0, d => 0.5]
642+
sprob = SDEProblem(sys, u0, tspan, ps)
643+
@test sprob.f.g(sprob.u0, sprob.p, sprob.tspan[1]) isa SVector{2, Float64}
644+
@test_nowarn solve(sprob, ImplicitEM())

0 commit comments

Comments
 (0)