Skip to content

Commit fa7e79c

Browse files
committed
use Val(dims) and Iterators.peel for _dim_stack
1 parent a9db4de commit fa7e79c

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

base/abstractarray.jl

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2662,44 +2662,41 @@ _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::IteratorSize, A) where {T,S}
26622662
_vec_axis(A, ax=_axes(A)) = length(ax) == 1 ? only(ax) : OneTo(prod(length, ax; init=1))
26632663

26642664
function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S}
2665-
xit = iterate(A)
2665+
xit = Iterators.peel(A)
26662666
nothing === xit && return _empty_stack(dims, T, S, A)
2667-
x1, _ = xit
2667+
x1, xrest = xit
26682668
ax1 = _axes(x1)
26692669
N1 = length(ax1)+1
26702670
dims in 1:N1 || throw(ArgumentError("cannot stack slices ndims(x) = $(N1-1) along dims = $dims"))
26712671

26722672
newaxis = _vec_axis(A)
2673-
outax = ntuple(d -> d==dims ? newaxis : _axes(x1)[d - (d>dims)], N1)
2673+
outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1)
26742674
B = similar(_first_array(x1, A), T, outax...)
26752675

2676-
iit = iterate(newaxis)
2677-
while xit !== nothing
2678-
x, state = xit
2679-
i, istate = iit
2680-
_stack_size_check(x, ax1)
2681-
@inbounds if dims==1
2682-
inds1 = ntuple(d -> d==1 ? i : Colon(), N1)
2683-
if x isa AbstractArray
2684-
B[inds1...] = x
2685-
else
2686-
copyto!(view(B, inds1...), x)
2687-
end
2688-
else
2689-
inds = ntuple(d -> d==dims ? i : Colon(), N1)
2690-
if x isa AbstractArray
2691-
B[inds...] = x
2692-
else
2693-
# This is where the type-instability of inds hurts, but it is pretty exotic:
2694-
copyto!(view(B, inds...), x)
2695-
end
2696-
end
2697-
xit = iterate(A, state)
2698-
iit = iterate(newaxis, istate)
2676+
if dims == 1
2677+
_dim_stack!(Val(1), B, x1, xrest)
2678+
elseif dims == 2
2679+
_dim_stack!(Val(2), B, x1, xrest)
2680+
else
2681+
_dim_stack!(Val(dims), B, x1, xrest)
26992682
end
27002683
B
27012684
end
27022685

2686+
function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims}
2687+
before = ntuple(d -> Colon(), dims - 1)
2688+
after = ntuple(d -> Colon(), ndims(B) - dims)
2689+
2690+
i = firstindex(B, dims)
2691+
copyto!(view(B, before..., i, after...), x1)
2692+
2693+
for x in xrest
2694+
_stack_size_check(x, _axes(x1))
2695+
i += 1
2696+
@inbounds copyto!(view(B, before..., i, after...), x)
2697+
end
2698+
end
2699+
27032700
@inline function _stack_size_check(x, ax1::Tuple)
27042701
if _axes(x) != ax1
27052702
uax1 = UnitRange.(ax1)

test/abstractarray.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,20 @@ end
16061606
# Trivial, because numbers are iterable:
16071607
@test stack(abs2, 1:3) == [1, 4, 9] == collect(Iterators.flatten(abs2(x) for x in 1:3))
16081608

1609+
# Allocation tests
1610+
xv = [rand(10) for _ in 1:100]
1611+
xt = Tuple.(xv)
1612+
for dims in (1, 2, :)
1613+
@test stack(xv; dims) == stack(xt; dims)
1614+
@test 9000 > @allocated stack(xv; dims)
1615+
@test 9000 > @allocated stack(xt; dims)
1616+
end
1617+
xr = (reshape(1:1000,10,10,10) for _ = 1:1000)
1618+
for dims in (1, 2, 3, :)
1619+
stack(xr; dims)
1620+
@test 8.1e6 > @allocated stack(xr; dims)
1621+
end
1622+
16091623
# Mismatched sizes
16101624
@test_throws DimensionMismatch stack([1:2, 1:3])
16111625
@test_throws DimensionMismatch stack([1:2, 1:3]; dims=1)

0 commit comments

Comments
 (0)