Skip to content

Commit c2382fd

Browse files
committed
make (h/v/hv)cat more type stable
1 parent e575b5f commit c2382fd

File tree

2 files changed

+33
-39
lines changed

2 files changed

+33
-39
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.5.0"
3+
version = "3.5.1"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/blockmap.jl

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ BlockMap{T}(maps::As, rows::Rs) where {T, As<:LinearMapTuple, Rs} =
2424

2525
MulStyle(A::BlockMap) = MulStyle(A.maps...)
2626

27+
function _getranges(maps, dim, inds::NTuple{N,Int}=ntuple(identity, Val(length(maps)))) where {N}
28+
sizes = ntuple(i -> (@inbounds size(maps[inds[i]], dim)), Val(N))
29+
ends = cumsum(sizes)
30+
starts = (1, (1 .+ Base.front(ends))...)
31+
return UnitRange.(starts, ends)
32+
end
33+
2734
"""
2835
rowcolranges(maps, rows)
2936
@@ -32,24 +39,19 @@ map in `maps`, according to its position in a virtual matrix representation of t
3239
block linear map obtained from `hvcat(rows, maps...)`.
3340
"""
3441
function rowcolranges(maps, rows)
35-
rowranges = ntuple(n->1:0, Val(length(rows)))
36-
colranges = ntuple(n->1:0, Val(length(maps)))
37-
mapind = 0
38-
rowstart = 1
39-
for (i, row) in enumerate(rows)
40-
mapind += 1
41-
rowend = rowstart + Int(size(maps[mapind], 1))::Int - 1
42-
rowranges = Base.setindex(rowranges, rowstart:rowend, i)
43-
colstart = 1
44-
colend = Int(size(maps[mapind], 2))::Int
45-
colranges = Base.setindex(colranges, colstart:colend, mapind)
46-
for colind in 2:row
47-
mapind += 1
48-
colstart = colend + 1
49-
colend += Int(size(maps[mapind], 2))::Int
50-
colranges = Base.setindex(colranges, colstart:colend, mapind)
51-
end
52-
rowstart = rowend + 1
42+
# find indices of the row-wise first maps
43+
firstmapinds = cumsum((1, Base.front(rows)...))
44+
# compute rowranges from size(map, 1) of the row-wise first maps
45+
rowranges = _getranges(maps, 1, firstmapinds)
46+
47+
# compute ranges from size(map, 1) as if all in one row
48+
temp = _getranges(maps, 2)
49+
# introduce "line breaks"
50+
colranges = ntuple(Val(length(maps))) do i
51+
# for each map find the index of the respective row-wise first map
52+
@inbounds firstmapind = firstmapinds[something(findlast(<=(i), firstmapinds), 1)]
53+
# shift ranges by the first col-index of the row-wise first map
54+
return @inbounds temp[i] .- first(temp[firstmapind]) .+ 1
5355
end
5456
return rowranges, colranges
5557
end
@@ -82,17 +84,13 @@ function Base.hcat(As::Union{LinearMap, UniformScaling, AbstractVecOrMat}...)
8284
T = promote_type(map(eltype, As)...)
8385
nbc = length(As)
8486

85-
nrows = -1
8687
# find first non-UniformScaling to detect number of rows
87-
for A in As
88-
if !(A isa UniformScaling)
89-
nrows = size(A, 1)
90-
break
91-
end
92-
end
93-
@assert nrows != -1
88+
j = findfirst(A -> !isa(A, UniformScaling), As)
9489
# this should not happen, function should only be called with at least one LinearMap
95-
return BlockMap{T}(promote_to_lmaps(ntuple(i->nrows, nbc), 1, 1, As...), (nbc,))
90+
@assert !isnothing(j)
91+
@inbounds nrows = size(As[j], 1)
92+
93+
return BlockMap{T}(promote_to_lmaps(ntuple(_ -> nrows, Val(nbc)), 1, 1, As...), (nbc,))
9694
end
9795

9896
############
@@ -124,18 +122,14 @@ function Base.vcat(As::Union{LinearMap,UniformScaling,AbstractVecOrMat}...)
124122
T = promote_type(map(eltype, As)...)
125123
nbr = length(As)
126124

127-
ncols = -1
128-
# find first non-UniformScaling to detect number of columns
129-
for A in As
130-
if !(A isa UniformScaling)
131-
ncols = size(A, 2)
132-
break
133-
end
134-
end
135-
@assert ncols != -1
125+
# find first non-UniformScaling to detect number of rows
126+
j = findfirst(A -> !isa(A, UniformScaling), As)
136127
# this should not happen, function should only be called with at least one LinearMap
137-
rows = ntuple(i->1, nbr)
138-
return BlockMap{T}(promote_to_lmaps(ntuple(i->ncols, nbr), 1, 2, As...), rows)
128+
@assert !isnothing(j)
129+
@inbounds ncols = size(As[j], 2)
130+
131+
rows = ntuple(_ -> 1, Val(nbr))
132+
return BlockMap{T}(promote_to_lmaps(ntuple(_ -> ncols, Val(nbr)), 1, 2, As...), rows)
139133
end
140134

141135
############

0 commit comments

Comments
 (0)