Skip to content

Commit c5712ac

Browse files
committed
use Val(dims) and Iterators.peel for _dim_stack
1 parent 5557072 commit c5712ac

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

base/abstractarray.jl

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,44 +2729,41 @@ _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::IteratorSize, A) where {T,S}
27292729
_vec_axis(A, ax=_axes(A)) = length(ax) == 1 ? only(ax) : OneTo(prod(length, ax; init=1))
27302730

27312731
function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S}
2732-
xit = iterate(A)
2732+
xit = Iterators.peel(A)
27332733
nothing === xit && return _empty_stack(dims, T, S, A)
2734-
x1, _ = xit
2734+
x1, xrest = xit
27352735
ax1 = _axes(x1)
27362736
N1 = length(ax1)+1
27372737
dims in 1:N1 || throw(ArgumentError("cannot stack slices ndims(x) = $(N1-1) along dims = $dims"))
27382738

27392739
newaxis = _vec_axis(A)
2740-
outax = ntuple(d -> d==dims ? newaxis : _axes(x1)[d - (d>dims)], N1)
2740+
outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1)
27412741
B = similar(_first_array(x1, A), T, outax...)
27422742

2743-
iit = iterate(newaxis)
2744-
while xit !== nothing
2745-
x, state = xit
2746-
i, istate = iit
2747-
_stack_size_check(x, ax1)
2748-
@inbounds if dims==1
2749-
inds1 = ntuple(d -> d==1 ? i : Colon(), N1)
2750-
if x isa AbstractArray
2751-
B[inds1...] = x
2752-
else
2753-
copyto!(view(B, inds1...), x)
2754-
end
2755-
else
2756-
inds = ntuple(d -> d==dims ? i : Colon(), N1)
2757-
if x isa AbstractArray
2758-
B[inds...] = x
2759-
else
2760-
# This is where the type-instability of inds hurts, but it is pretty exotic:
2761-
copyto!(view(B, inds...), x)
2762-
end
2763-
end
2764-
xit = iterate(A, state)
2765-
iit = iterate(newaxis, istate)
2743+
if dims == 1
2744+
_dim_stack!(Val(1), B, x1, xrest)
2745+
elseif dims == 2
2746+
_dim_stack!(Val(2), B, x1, xrest)
2747+
else
2748+
_dim_stack!(Val(dims), B, x1, xrest)
27662749
end
27672750
B
27682751
end
27692752

2753+
function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims}
2754+
before = ntuple(d -> Colon(), dims - 1)
2755+
after = ntuple(d -> Colon(), ndims(B) - dims)
2756+
2757+
i = firstindex(B, dims)
2758+
copyto!(view(B, before..., i, after...), x1)
2759+
2760+
for x in xrest
2761+
_stack_size_check(x, _axes(x1))
2762+
i += 1
2763+
@inbounds copyto!(view(B, before..., i, after...), x)
2764+
end
2765+
end
2766+
27702767
@inline function _stack_size_check(x, ax1::Tuple)
27712768
if _axes(x) != ax1
27722769
uax1 = UnitRange.(ax1)

test/abstractarray.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,20 @@ end
16111611
# Trivial, because numbers are iterable:
16121612
@test stack(abs2, 1:3) == [1, 4, 9] == collect(Iterators.flatten(abs2(x) for x in 1:3))
16131613

1614+
# Allocation tests
1615+
xv = [rand(10) for _ in 1:100]
1616+
xt = Tuple.(xv)
1617+
for dims in (1, 2, :)
1618+
@test stack(xv; dims) == stack(xt; dims)
1619+
@test 9000 > @allocated stack(xv; dims)
1620+
@test 9000 > @allocated stack(xt; dims)
1621+
end
1622+
xr = (reshape(1:1000,10,10,10) for _ = 1:1000)
1623+
for dims in (1, 2, 3, :)
1624+
stack(xr; dims)
1625+
@test 8.1e6 > @allocated stack(xr; dims)
1626+
end
1627+
16141628
# Mismatched sizes
16151629
@test_throws DimensionMismatch stack([1:2, 1:3])
16161630
@test_throws DimensionMismatch stack([1:2, 1:3]; dims=1)

0 commit comments

Comments
 (0)