Skip to content

Commit f6b059c

Browse files
Merge pull request #543 from SciML/ap/propagate_iip
Propagate IIP information in the Wrapper Functions
2 parents ece1966 + d7f5284 commit f6b059c

File tree

3 files changed

+66
-21
lines changed

3 files changed

+66
-21
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.8.1"
4+
version = "2.8.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -69,6 +69,8 @@ PartialFunctions = "1.1"
6969
PrecompileTools = "1"
7070
Preferences = "1.3"
7171
Printf = "1.9"
72+
PyCall = "1.96"
73+
PythonCall = "0.9"
7274
RCall = "0.13.18"
7375
RecipesBase = "0.7.0, 0.8, 1.0"
7476
RecursiveArrayTools = "2.33"

src/function_wrappers.jl

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,99 @@
1-
mutable struct TimeGradientWrapper{fType, uType, P} <: Function
1+
mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunction{iip}
22
f::fType
33
uprev::uType
44
p::P
55
end
6-
(ff::TimeGradientWrapper)(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
7-
(ff::TimeGradientWrapper)(du2, t) = ff.f(du2, ff.uprev, ff.p, t)
86

9-
mutable struct UJacobianWrapper{fType, tType, P} <: Function
7+
function TimeGradientWrapper{iip}(f::F, uprev, p) where {F, iip}
8+
return TimeGradientWrapper{iip, F, typeof(uprev), typeof(p)}(f, uprev, p)
9+
end
10+
function TimeGradientWrapper(f::F, uprev, p) where {F}
11+
return TimeGradientWrapper{isinplace(f, 4)}(f, uprev, p)
12+
end
13+
14+
(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
15+
(ff::TimeGradientWrapper{true})(du2, t) = ff.f(du2, ff.uprev, ff.p, t)
16+
17+
(ff::TimeGradientWrapper{false})(t) = ff.f(ff.uprev, ff.p, t)
18+
19+
mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{iip}
1020
f::fType
1121
t::tType
1222
p::P
1323
end
1424

15-
(ff::UJacobianWrapper)(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
16-
(ff::UJacobianWrapper)(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
17-
(ff::UJacobianWrapper)(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
18-
(ff::UJacobianWrapper)(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)
25+
function UJacobianWrapper{iip}(f::F, t, p) where {F, iip}
26+
return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
27+
end
28+
UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p)
29+
30+
(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
31+
(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
32+
(ff::UJacobianWrapper{true})(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
33+
(ff::UJacobianWrapper{true})(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)
1934

20-
mutable struct TimeDerivativeWrapper{F, uType, P} <: Function
35+
(ff::UJacobianWrapper{false})(uprev) = ff.f(uprev, ff.p, ff.t)
36+
(ff::UJacobianWrapper{false})(uprev, p, t) = ff.f(uprev, p, t)
37+
38+
mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{iip}
2139
f::F
2240
u::uType
2341
p::P
2442
end
25-
(ff::TimeDerivativeWrapper)(t) = ff.f(ff.u, ff.p, t)
2643

27-
mutable struct UDerivativeWrapper{F, tType, P} <: Function
44+
function TimeDerivativeWrapper{iip}(f::F, u, p) where {F, iip}
45+
return TimeDerivativeWrapper{iip, F, typeof(u), typeof(p)}(f, u, p)
46+
end
47+
function TimeDerivativeWrapper(f::F, u, p) where {F}
48+
return TimeDerivativeWrapper{isinplace(f, 4)}(f, u, p)
49+
end
50+
51+
(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t)
52+
(ff::TimeDerivativeWrapper{true})(du1, t) = ff.f(du1, ff.u, ff.p, t)
53+
(ff::TimeDerivativeWrapper{true})(t) = (du1 = similar(ff.u); ff.f(du1, ff.u, ff.p, t); du1)
54+
55+
mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip}
2856
f::F
2957
t::tType
3058
p::P
3159
end
32-
(ff::UDerivativeWrapper)(u) = ff.f(u, ff.p, ff.t)
3360

34-
mutable struct ParamJacobianWrapper{fType, tType, uType} <: Function
61+
function UDerivativeWrapper{iip}(f::F, t, p) where {F, iip}
62+
return UDerivativeWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
63+
end
64+
UDerivativeWrapper(f::F, t, p) where {F} = UDerivativeWrapper{isinplace(f, 4)}(f, t, p)
65+
66+
(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t)
67+
(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t)
68+
(ff::UDerivativeWrapper{true})(u) = (du1 = similar(u); ff.f(du1, u, ff.p, ff.t); du1)
69+
70+
mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFunction{iip}
3571
f::fType
3672
t::tType
3773
u::uType
3874
end
3975

40-
function (ff::ParamJacobianWrapper)(du1, p)
41-
ff.f(du1, ff.u, p, ff.t)
76+
function ParamJacobianWrapper{iip}(f::F, t, u) where {F, iip}
77+
return ParamJacobianWrapper{iip, F, typeof(t), typeof(u)}(f, t, u)
4278
end
79+
ParamJacobianWrapper(f::F, t, u) where {F} = ParamJacobianWrapper{isinplace(f, 4)}(f, t, u)
4380

44-
function (ff::ParamJacobianWrapper)(p)
81+
(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t)
82+
function (ff::ParamJacobianWrapper{true})(p)
4583
du1 = similar(p, size(ff.u))
4684
ff.f(du1, ff.u, p, ff.t)
4785
return du1
4886
end
87+
(ff::ParamJacobianWrapper{false})(p) = ff.f(ff.u, p, ff.t)
4988

50-
mutable struct JacobianWrapper{fType, pType}
89+
mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip}
5190
f::fType
5291
p::pType
5392
end
5493

55-
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
56-
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
94+
JacobianWrapper{iip}(f::F, p) where {F, iip} = JacobianWrapper{iip, F, typeof(p)}(f, p)
95+
JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p)
96+
97+
(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
98+
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
99+
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

test/aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ end
2525
@testset "Aqua tests (additional)" begin
2626
Aqua.test_undefined_exports(SciMLBase)
2727
Aqua.test_stale_deps(SciMLBase)
28-
Aqua.test_deps_compat(SciMLBase)
28+
Aqua.test_deps_compat(SciMLBase, check_extras = false)
2929
Aqua.test_project_extras(SciMLBase)
3030
# Aqua.test_project_toml_formatting(SciMLBase) # failing
3131
# Aqua.test_piracy(SciMLBase) # failing

0 commit comments

Comments
 (0)