Skip to content

Commit b10cddb

Browse files
Merge pull request #359 from jlchan/jc/VectorOfArray_multidim_helper
Specialize `Base.similar` for `VectorOfArray` with multidimensional parent
2 parents 239bd04 + 43072c8 commit b10cddb

File tree

5 files changed

+53
-57
lines changed

5 files changed

+53
-57
lines changed

src/vector_of_array.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,16 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
156156
VectorOfArray{T, N + 1, typeof(vec)}(vec)
157157
end
158158

159-
# allow multi-dimensional arrays as long as they're linearly indexed
159+
# allow multi-dimensional arrays as long as they're linearly indexed.
160+
# currently restricted to arrays whose elements are all the same type
160161
function VectorOfArray(array::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
161162
@assert IndexStyle(typeof(array)) isa IndexLinear
162163

163164
return VectorOfArray{T, N + 1, typeof(array)}(array)
164165
end
165166

167+
Base.parent(vec::VectorOfArray) = vec.u
168+
166169
function DiffEqArray(vec::AbstractVector{T},
167170
ts::AbstractVector,
168171
::NTuple{N, Int},
@@ -721,6 +724,18 @@ end
721724
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])
722725
end
723726

727+
# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
728+
function Base.similar(vec::VectorOfArray{
729+
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
730+
return VectorOfArray(similar(Base.parent(vec)))
731+
end
732+
733+
# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
734+
function Base.similar(vec::VectorOfArray{
735+
T, N, AT}) where {T, N, AT <: AbstractVector{<:AbstractArray{T}}}
736+
return Base.similar(vec, eltype(vec))
737+
end
738+
724739
# fill!
725740
# For DiffEqArray it ignores ts and fills only u
726741
function Base.fill!(VA::AbstractVectorOfArray, x)

test/basic_indexing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ foo!(u_matrix)
250250
foo!(u_vector)
251251
@test u_matrix u_vector
252252

253+
# test that, for VectorOfArray with multi-dimensional parent arrays,
254+
# `similar` preserves the structure of the parent array
255+
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
256+
253257
# test efficiency
254258
num_allocs = @allocations foo!(u_matrix)
255259
@test num_allocs == 0

test/downstream/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
33
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
44
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
56
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
67
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
78

@@ -10,4 +11,5 @@ ModelingToolkit = "8.33"
1011
MonteCarloMeasurements = "1.1"
1112
OrdinaryDiffEq = "6.31"
1213
Unitful = "1.17"
13-
Tracker = "0.2"
14+
Tracker = "0.2"
15+
StaticArrays = "1"

test/upstream.jl renamed to test/downstream/odesolve.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface
1+
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface, StaticArrays
22
function lorenz(du, u, p, t)
33
du[1] = 10.0 * (u[2] - u[1])
44
du[2] = u[1] * (28.0 - u[3]) - u[2]
@@ -49,3 +49,14 @@ end
4949
ArrayPartition(zeros(1), [0.75])),
5050
(0.0, 1.0)),
5151
Rodas5()).retcode == ReturnCode.Success
52+
53+
function rhs!(duu::VectorOfArray, uu::VectorOfArray, p, t)
54+
du = parent(duu)
55+
u = parent(uu)
56+
du .= u
57+
end
58+
59+
u = fill(SVector{2}(ones(2)), 2, 3)
60+
ode = ODEProblem(rhs!, VectorOfArray(u), (0.0, 1.0))
61+
sol = solve(ode, Tsit5())
62+
@test SciMLBase.successful_retcode(sol)

test/runtests.jl

Lines changed: 18 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,67 +19,31 @@ end
1919

2020
@time begin
2121
if GROUP == "Core" || GROUP == "All"
22-
@time @safetestset "Quality Assurance" begin
23-
include("qa.jl")
24-
end
25-
@time @safetestset "Utils Tests" begin
26-
include("utils_test.jl")
27-
end
28-
@time @safetestset "NamedArrayPartition Tests" begin
29-
include("named_array_partition_tests.jl")
30-
end
31-
@time @safetestset "Partitions Tests" begin
32-
include("partitions_test.jl")
33-
end
34-
@time @safetestset "VecOfArr Indexing Tests" begin
35-
include("basic_indexing.jl")
36-
end
37-
@time @safetestset "SymbolicIndexingInterface API test" begin
38-
include("symbolic_indexing_interface_test.jl")
39-
end
40-
@time @safetestset "VecOfArr Interface Tests" begin
41-
include("interface_tests.jl")
42-
end
43-
@time @safetestset "Table traits" begin
44-
include("tabletraits.jl")
45-
end
46-
@time @safetestset "StaticArrays Tests" begin
47-
include("copy_static_array_test.jl")
48-
end
49-
@time @safetestset "Linear Algebra Tests" begin
50-
include("linalg.jl")
51-
end
52-
@time @safetestset "Upstream Tests" begin
53-
include("upstream.jl")
54-
end
55-
@time @safetestset "Adjoint Tests" begin
56-
include("adjoints.jl")
57-
end
58-
@time @safetestset "Measurement Tests" begin
59-
include("measurements.jl")
60-
end
22+
@time @safetestset "Quality Assurance" include("qa.jl")
23+
@time @safetestset "Utils Tests" include("utils_test.jl")
24+
@time @safetestset "NamedArrayPartition Tests" include("named_array_partition_tests.jl")
25+
@time @safetestset "Partitions Tests" include("partitions_test.jl")
26+
@time @safetestset "VecOfArr Indexing Tests" include("basic_indexing.jl")
27+
@time @safetestset "SymbolicIndexingInterface API test" include("symbolic_indexing_interface_test.jl")
28+
@time @safetestset "VecOfArr Interface Tests" include("interface_tests.jl")
29+
@time @safetestset "Table traits" include("tabletraits.jl")
30+
@time @safetestset "StaticArrays Tests" include("copy_static_array_test.jl")
31+
@time @safetestset "Linear Algebra Tests" include("linalg.jl")
32+
@time @safetestset "Adjoint Tests" include("adjoints.jl")
33+
@time @safetestset "Measurement Tests" include("measurements.jl")
6134
end
6235

6336
if GROUP == "Downstream"
6437
activate_downstream_env()
65-
@time @safetestset "DiffEqArray Indexing Tests" begin
66-
include("downstream/symbol_indexing.jl")
67-
end
68-
@time @safetestset "Event Tests with ArrayPartition" begin
69-
include("downstream/downstream_events.jl")
70-
end
71-
@time @safetestset "Measurements and Units" begin
72-
include("downstream/measurements_and_units.jl")
73-
end
74-
@time @safetestset "TrackerExt" begin
75-
include("downstream/TrackerExt.jl")
76-
end
38+
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
39+
@time @safetestset "ODE Solve Tests" include("downstream/odesolve.jl")
40+
@time @safetestset "Event Tests with ArrayPartition" include("downstream/downstream_events.jl")
41+
@time @safetestset "Measurements and Units" include("downstream/measurements_and_units.jl")
42+
@time @safetestset "TrackerExt" include("downstream/TrackerExt.jl")
7743
end
7844

7945
if GROUP == "GPU"
8046
activate_gpu_env()
81-
@time @safetestset "VectorOfArray GPU" begin
82-
include("gpu/vectorofarray_gpu.jl")
83-
end
47+
@time @safetestset "VectorOfArray GPU" include("gpu/vectorofarray_gpu.jl")
8448
end
8549
end

0 commit comments

Comments
 (0)