@@ -231,6 +231,9 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
231
231
if isdde
232
232
eqs = delay_to_function (sys, eqs)
233
233
end
234
+ if eqs isa AbstractMatrix && isdiag (eqs)
235
+ eqs = diag (eqs)
236
+ end
234
237
u = map (x -> time_varying_as_func (value (x), sys), dvs)
235
238
p = if has_index_cache (sys) && get_index_cache (sys) != = nothing
236
239
reorder_parameters (get_index_cache (sys), ps)
@@ -403,14 +406,14 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
403
406
checks = false )
404
407
end
405
408
406
- function DiffEqBase. SDEFunction {iip} (sys:: SDESystem , dvs = unknowns (sys),
409
+ function DiffEqBase. SDEFunction {iip, specialize } (sys:: SDESystem , dvs = unknowns (sys),
407
410
ps = parameters (sys),
408
411
u0 = nothing ;
409
412
version = nothing , tgrad = false , sparse = false ,
410
413
jac = false , Wfact = false , eval_expression = false ,
411
414
eval_module = @__MODULE__ ,
412
415
checkbounds = false ,
413
- kwargs... ) where {iip}
416
+ kwargs... ) where {iip, specialize }
414
417
if ! iscomplete (sys)
415
418
error (" A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`" )
416
419
end
@@ -480,7 +483,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
480
483
481
484
observedfun = ObservedFunctionCache (sys; eval_expression, eval_module)
482
485
483
- SDEFunction {iip} (f, g,
486
+ SDEFunction {iip, specialize } (f, g,
484
487
sys = sys,
485
488
jac = _jac === nothing ? nothing : _jac,
486
489
tgrad = _tgrad === nothing ? nothing : _tgrad,
@@ -505,6 +508,16 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
505
508
SDEFunction {true} (sys, args... ; kwargs... )
506
509
end
507
510
511
+ function DiffEqBase. SDEFunction {true} (sys:: SDESystem , args... ;
512
+ kwargs... )
513
+ SDEFunction {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
514
+ end
515
+
516
+ function DiffEqBase. SDEFunction {false} (sys:: SDESystem , args... ;
517
+ kwargs... )
518
+ SDEFunction {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
519
+ end
520
+
508
521
"""
509
522
```julia
510
523
DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
@@ -583,14 +596,16 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
583
596
SDEFunctionExpr {true} (sys, args... ; kwargs... )
584
597
end
585
598
586
- function DiffEqBase. SDEProblem {iip} (sys:: SDESystem , u0map = [], tspan = get_tspan (sys),
599
+ function DiffEqBase. SDEProblem {iip, specialize} (
600
+ sys:: SDESystem , u0map = [], tspan = get_tspan (sys),
587
601
parammap = DiffEqBase. NullParameters ();
588
602
sparsenoise = nothing , check_length = true ,
589
- callback = nothing , kwargs... ) where {iip}
603
+ callback = nothing , kwargs... ) where {iip, specialize }
590
604
if ! iscomplete (sys)
591
605
error (" A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`" )
592
606
end
593
- f, u0, p = process_DEProblem (SDEFunction{iip}, sys, u0map, parammap; check_length,
607
+ f, u0, p = process_DEProblem (
608
+ SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
594
609
kwargs... )
595
610
cbs = process_events (sys; callback, kwargs... )
596
611
sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
@@ -628,6 +643,21 @@ function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
628
643
SDEProblem {true} (sys, args... ; kwargs... )
629
644
end
630
645
646
+ function DiffEqBase. SDEProblem (sys:: SDESystem ,
647
+ u0map:: StaticArray ,
648
+ args... ;
649
+ kwargs... )
650
+ SDEProblem {false, SciMLBase.FullSpecialize} (sys, u0map, args... ; kwargs... )
651
+ end
652
+
653
+ function DiffEqBase. SDEProblem {true} (sys:: SDESystem , args... ; kwargs... )
654
+ SDEProblem {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
655
+ end
656
+
657
+ function DiffEqBase. SDEProblem {false} (sys:: SDESystem , args... ; kwargs... )
658
+ SDEProblem {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
659
+ end
660
+
631
661
"""
632
662
```julia
633
663
DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
0 commit comments