Skip to content

Commit 400f292

Browse files
Merge pull request #636 from lxvm/proto
allow `integrand_prototype` for oop integral functions
2 parents 14c0b32 + 263f910 commit 400f292

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

src/scimlfunctions.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2056,9 +2056,10 @@ out-of-place handling.
20562056
IntegralFunction{iip,specialize}(f, [integrand_prototype])
20572057
```
20582058
2059-
Note that only `f` is required, and in the case of inplace integrands a mutable container
2059+
Note that only `f` is required, and in the case of inplace integrands a mutable array
20602060
`integrand_prototype` to store the result of the integrand. If `integrand_prototype` is
2061-
present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place.
2061+
present for either in-place or out-of-place integrands it is used to infer the return type
2062+
of the integrand.
20622063
20632064
## iip: In-Place vs Out-Of-Place
20642065
@@ -2114,14 +2115,14 @@ BatchIntegralFunction{iip,specialize}(bf, [integrand_prototype];
21142115
max_batch=typemax(Int))
21152116
```
21162117
Note that only `bf` is required, and in the case of inplace integrands a mutable
2117-
container `integrand_prototype` to store a batch of integrand evaluations, with
2118+
array `integrand_prototype` to store a batch of integrand evaluations, with
21182119
a last "batching" dimension.
21192120
21202121
The keyword `max_batch` is used to set a soft limit on the number of points to
21212122
batch at the same time so that memory usage is controlled.
21222123
2123-
If `integrand_prototype` is present, `bf` is interpreted as in-place, and
2124-
otherwise `bf` is assumed to be out-of-place.
2124+
If `integrand_prototype` is present for either in-place or out-of-place integrands it is
2125+
used to infer the return type of the integrand.
21252126
21262127
## iip: In-Place vs Out-Of-Place
21272128
@@ -3158,7 +3159,8 @@ function DAEFunction{iip, specialize}(f;
31583159
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
31593160
sys = __has_sys(f) ? f.sys : nothing,
31603161
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
3161-
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {iip,
3162+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {
3163+
iip,
31623164
specialize
31633165
}
31643166
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
@@ -3854,10 +3856,7 @@ function IntegralFunction(f)
38543856
end
38553857
function IntegralFunction(f, integrand_prototype)
38563858
calculated_iip = isinplace(f, 3, "integral", true)
3857-
if !calculated_iip
3858-
throw(IntegrandMismatchFunctionError(calculated_iip, true))
3859-
end
3860-
IntegralFunction{true}(f, integrand_prototype)
3859+
IntegralFunction{calculated_iip}(f, integrand_prototype)
38613860
end
38623861

38633862
function BatchIntegralFunction{iip, specialize}(f, integrand_prototype;
@@ -3890,10 +3889,7 @@ function BatchIntegralFunction(f; kwargs...)
38903889
end
38913890
function BatchIntegralFunction(f, integrand_prototype; kwargs...)
38923891
calculated_iip = isinplace(f, 3, "batchintegral", true)
3893-
if !calculated_iip
3894-
throw(IntegrandMismatchFunctionError(calculated_iip, true))
3895-
end
3896-
BatchIntegralFunction{true}(f, integrand_prototype; kwargs...)
3892+
BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...)
38973893
end
38983894

38993895
########## Utility functions

test/function_building_error_messages.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ intfiip(y, u, p) = y .= 1.0
474474
for (f, kws, iip) in (
475475
(intf, (;), false),
476476
(IntegralFunction(intf), (;), false),
477+
(IntegralFunction(intf, 1.0), (;), false),
477478
(intfiip, (; nout = 3), true),
478479
(IntegralFunction(intfiip, zeros(3)), (;), true)
479480
), domain in (((0.0, 1.0),), (([0.0], [1.0]),), (0.0, 1.0), ([0.0], [1.0]))
@@ -648,9 +649,9 @@ i1(u) = u
648649
itoo(y, u, p, a) = y .= u * p
649650

650651
IntegralFunction(ioop)
652+
IntegralFunction(ioop, 0.0)
651653
IntegralFunction(iiip, Float64[])
652654

653-
@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(ioop, Float64[])
654655
@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(iiip)
655656
@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(i1)
656657
@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo)
@@ -665,10 +666,11 @@ bitoo(y, u, p, a) = y .= p .* u
665666

666667
BatchIntegralFunction(boop)
667668
BatchIntegralFunction(boop, max_batch = 20)
669+
BatchIntegralFunction(boop, Float64[])
670+
BatchIntegralFunction(boop, Float64[], max_batch = 20)
668671
BatchIntegralFunction(biip, Float64[])
669672
BatchIntegralFunction(biip, Float64[], max_batch = 20)
670673

671-
@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[])
672674
@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip)
673675
@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1)
674676
@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo)

0 commit comments

Comments
 (0)