Skip to content

Commit baaad84

Browse files
committed
Additional initialization path
1 parent da5aaff commit baaad84

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/function_wrappers.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunctio
44
p::P
55
end
66

7+
function TimeGradientWrapper{iip}(f::F, uprev, p) where {F, iip}
8+
return TimeGradientWrapper{iip, F, typeof(uprev), typeof(p)}(f, uprev, p)
9+
end
710
function TimeGradientWrapper(f::F, uprev, p) where {F}
8-
return TimeGradientWrapper{isinplace(f, 4), F, typeof(uprev), typeof(p)}(f, uprev, p)
11+
return TimeGradientWrapper{isinplace(f, 4)}(f, uprev, p)
912
end
1013

1114
(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
@@ -19,9 +22,10 @@ mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{i
1922
p::P
2023
end
2124

22-
function UJacobianWrapper(f::F, t, p) where {F}
23-
return UJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p)
25+
function UJacobianWrapper{iip}(f::F, t, p) where {F, iip}
26+
return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
2427
end
28+
UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p)
2529

2630
(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
2731
(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
@@ -37,8 +41,11 @@ mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{
3741
p::P
3842
end
3943

44+
function TimeDerivativeWrapper{iip}(f::F, u, p) where {F, iip}
45+
return TimeDerivativeWrapper{iip, F, typeof(u), typeof(p)}(f, u, p)
46+
end
4047
function TimeDerivativeWrapper(f::F, u, p) where {F}
41-
return TimeDerivativeWrapper{isinplace(f, 4), F, typeof(u), typeof(p)}(f, u, p)
48+
return TimeDerivativeWrapper{isinplace(f, 4)}(f, u, p)
4249
end
4350

4451
(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t)
@@ -51,9 +58,10 @@ mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip
5158
p::P
5259
end
5360

54-
function UDerivativeWrapper(f::F, t, p) where {F}
55-
return UDerivativeWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p)
61+
function UDerivativeWrapper{iip}(f::F, t, p) where {F, iip}
62+
return UDerivativeWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
5663
end
64+
UDerivativeWrapper(f::F, t, p) where {F} = UDerivativeWrapper{isinplace(f, 4)}(f, t, p)
5765

5866
(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t)
5967
(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t)
@@ -65,9 +73,10 @@ mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFu
6573
u::uType
6674
end
6775

68-
function ParamJacobianWrapper(f::F, t, u) where {F}
69-
return ParamJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(u)}(f, t, u)
76+
function ParamJacobianWrapper{iip}(f::F, t, u) where {F, iip}
77+
return ParamJacobianWrapper{iip, F, typeof(t), typeof(u)}(f, t, u)
7078
end
79+
ParamJacobianWrapper(f::F, t, u) where {F} = ParamJacobianWrapper{isinplace(f, 4)}(f, t, u)
7180

7281
(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t)
7382
function (ff::ParamJacobianWrapper{true})(p)
@@ -82,9 +91,9 @@ mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip}
8291
p::pType
8392
end
8493

85-
function JacobianWrapper(f::F, p) where {F}
86-
return JacobianWrapper{isinplace(f, 4), F, typeof(p)}(f, p)
87-
end
94+
JacobianWrapper{iip}(f::F, p) where {F, iip} = JacobianWrapper{iip, F, typeof(p)}(f, p)
95+
JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p)
8896

8997
(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
98+
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
9099
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

0 commit comments

Comments
 (0)