Skip to content

Commit 9017479

Browse files
committed
Propagate IIP information in the Wrapper Functions
1 parent ece1966 commit 9017479

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.8.1"
4+
version = "2.8.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/function_wrappers.jl

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,90 @@
1-
mutable struct TimeGradientWrapper{fType, uType, P} <: Function
1+
mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunction{iip}
22
f::fType
33
uprev::uType
44
p::P
55
end
6-
(ff::TimeGradientWrapper)(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
7-
(ff::TimeGradientWrapper)(du2, t) = ff.f(du2, ff.uprev, ff.p, t)
86

9-
mutable struct UJacobianWrapper{fType, tType, P} <: Function
7+
function TimeGradientWrapper(f::F, uprev, p) where {F}
8+
return TimeGradientWrapper{isinplace(f), F, typeof(uprev), typeof(p)}(f, uprev, p)
9+
end
10+
11+
(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
12+
(ff::TimeGradientWrapper{true})(du2, t) = ff.f(du2, ff.uprev, ff.p, t)
13+
14+
(ff::TimeGradientWrapper{false})(t) = ff.f(ff.uprev, ff.p, t)
15+
16+
mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{iip}
1017
f::fType
1118
t::tType
1219
p::P
1320
end
1421

15-
(ff::UJacobianWrapper)(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
16-
(ff::UJacobianWrapper)(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
17-
(ff::UJacobianWrapper)(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
18-
(ff::UJacobianWrapper)(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)
22+
function UJacobianWrapper(f::F, t, p) where {F}
23+
return UJacobianWrapper{isinplace(f), F, typeof(t), typeof(p)}(f, t, p)
24+
end
25+
26+
(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
27+
(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
28+
(ff::UJacobianWrapper{true})(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
29+
(ff::UJacobianWrapper{true})(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)
1930

20-
mutable struct TimeDerivativeWrapper{F, uType, P} <: Function
31+
(ff::UJacobianWrapper{false})(uprev) = ff.f(uprev, ff.p, ff.t)
32+
(ff::UJacobianWrapper{false})(uprev, p, t) = ff.f(uprev, p, t)
33+
34+
mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{iip}
2135
f::F
2236
u::uType
2337
p::P
2438
end
25-
(ff::TimeDerivativeWrapper)(t) = ff.f(ff.u, ff.p, t)
2639

27-
mutable struct UDerivativeWrapper{F, tType, P} <: Function
40+
function TimeDerivativeWrapper(f::F, u, p) where {F}
41+
return TimeDerivativeWrapper{isinplace(f), F, typeof(u), typeof(p)}(f, u, p)
42+
end
43+
44+
(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t)
45+
(ff::TimeDerivativeWrapper{true})(du1, t) = ff.f(du1, ff.u, ff.p, t)
46+
(ff::TimeDerivativeWrapper{true})(t) = (du1 = similar(ff.u); ff.f(du1, ff.u, ff.p, t); du1)
47+
48+
mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip}
2849
f::F
2950
t::tType
3051
p::P
3152
end
32-
(ff::UDerivativeWrapper)(u) = ff.f(u, ff.p, ff.t)
3353

34-
mutable struct ParamJacobianWrapper{fType, tType, uType} <: Function
54+
function UDerivativeWrapper(f::F, t, p) where {F}
55+
return UDerivativeWrapper{isinplace(f), F, typeof(t), typeof(p)}(f, t, p)
56+
end
57+
58+
(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t)
59+
(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t)
60+
(ff::UDerivativeWrapper{true})(u) = (du1 = similar(u); ff.f(du1, u, ff.p, ff.t); du1)
61+
62+
mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFunction{iip}
3563
f::fType
3664
t::tType
3765
u::uType
3866
end
3967

40-
function (ff::ParamJacobianWrapper)(du1, p)
41-
ff.f(du1, ff.u, p, ff.t)
68+
function ParamJacobianWrapper(f::F, t, u) where {F}
69+
return ParamJacobianWrapper{isinplace(f), F, typeof(t), typeof(u)}(f, t, u)
4270
end
4371

44-
function (ff::ParamJacobianWrapper)(p)
72+
(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t)
73+
function (ff::ParamJacobianWrapper{true})(p)
4574
du1 = similar(p, size(ff.u))
4675
ff.f(du1, ff.u, p, ff.t)
4776
return du1
4877
end
78+
(ff::ParamJacobianWrapper{false})(p) = ff.f(ff.u, p, ff.t)
4979

50-
mutable struct JacobianWrapper{fType, pType}
80+
mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip}
5181
f::fType
5282
p::pType
5383
end
5484

55-
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
56-
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
85+
function JacobianWrapper(f::F, p) where {F}
86+
return JacobianWrapper{isinplace(f), F, typeof(p)}(f, p)
87+
end
88+
89+
(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
90+
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

0 commit comments

Comments
 (0)