Skip to content

Commit b382a37

Browse files
fix: infer oop form for SDEProblem/SDEFunction with StaticArrays
1 parent a203b86 commit b382a37

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -403,14 +403,14 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
403403
checks = false)
404404
end
405405

406-
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
406+
function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(sys),
407407
ps = parameters(sys),
408408
u0 = nothing;
409409
version = nothing, tgrad = false, sparse = false,
410410
jac = false, Wfact = false, eval_expression = false,
411411
eval_module = @__MODULE__,
412412
checkbounds = false,
413-
kwargs...) where {iip}
413+
kwargs...) where {iip, specialize}
414414
if !iscomplete(sys)
415415
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
416416
end
@@ -480,7 +480,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
480480

481481
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
482482

483-
SDEFunction{iip}(f, g,
483+
SDEFunction{iip, specialize}(f, g,
484484
sys = sys,
485485
jac = _jac === nothing ? nothing : _jac,
486486
tgrad = _tgrad === nothing ? nothing : _tgrad,
@@ -505,6 +505,16 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
505505
SDEFunction{true}(sys, args...; kwargs...)
506506
end
507507

508+
function DiffEqBase.SDEFunction{true}(sys::SDESystem, args...;
509+
kwargs...)
510+
SDEFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
511+
end
512+
513+
function DiffEqBase.SDEFunction{false}(sys::SDESystem, args...;
514+
kwargs...)
515+
SDEFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
516+
end
517+
508518
"""
509519
```julia
510520
DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
@@ -583,14 +593,16 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
583593
SDEFunctionExpr{true}(sys, args...; kwargs...)
584594
end
585595

586-
function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspan(sys),
596+
function DiffEqBase.SDEProblem{iip, specialize}(
597+
sys::SDESystem, u0map = [], tspan = get_tspan(sys),
587598
parammap = DiffEqBase.NullParameters();
588599
sparsenoise = nothing, check_length = true,
589-
callback = nothing, kwargs...) where {iip}
600+
callback = nothing, kwargs...) where {iip, specialize}
590601
if !iscomplete(sys)
591602
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`")
592603
end
593-
f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; check_length,
604+
f, u0, p = process_DEProblem(
605+
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
594606
kwargs...)
595607
cbs = process_events(sys; callback, kwargs...)
596608
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
@@ -628,6 +640,21 @@ function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
628640
SDEProblem{true}(sys, args...; kwargs...)
629641
end
630642

643+
function DiffEqBase.SDEProblem(sys::SDESystem,
644+
u0map::StaticArray,
645+
args...;
646+
kwargs...)
647+
SDEProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
648+
end
649+
650+
function DiffEqBase.SDEProblem{true}(sys::SDESystem, args...; kwargs...)
651+
SDEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
652+
end
653+
654+
function DiffEqBase.SDEProblem{false}(sys::SDESystem, args...; kwargs...)
655+
SDEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
656+
end
657+
631658
"""
632659
```julia
633660
DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,

test/sdesystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,3 +614,18 @@ sys2 = complete(sys2)
614614
prob = SDEProblem(sys1, sts .=> [1.0, 0.0, 0.0],
615615
(0.0, 100.0), ps .=> (10.0, 26.0))
616616
solve(prob, LambaEulerHeun(), seed = 1)
617+
618+
# SDEProblem construction with StaticArrays
619+
# Issue#2814
620+
@parameters p d
621+
@variables x(t)
622+
@brownian a
623+
eqs = [D(x) ~ p - d * x + a * sqrt(p)]
624+
@mtkbuild sys = System(eqs, t)
625+
u0 = @SVector[x => 10.0]
626+
tspan = (0.0, 10.0)
627+
ps = @SVector[p => 5.0, d => 0.5]
628+
sprob = SDEProblem(ssys, u0, tspan, ps)
629+
@test !isinplace(sprob)
630+
@test !isinplace(sprob.f)
631+
@test_nowarn solve(sprob, ImplicitEM())

0 commit comments

Comments
 (0)