Skip to content

Commit a35f2e2

Browse files
committed
Slightly reorganize composition, improve coverage
1 parent 9fc8ed1 commit a35f2e2

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

src/composition.jl

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::AbstractVector
2222
MulStyle(A::CompositeMap) = MulStyle(A.maps...) === TwoArg() ? TwoArg() : ThreeArg()
2323

2424
# basic methods
25-
Base.size(A::CompositeMap) = (size(A.maps[end], 1), size(A.maps[1], 2))
26-
Base.axes(A::CompositeMap) = (axes(A.maps[end])[1], axes(A.maps[1])[2])
25+
Base.size(A::CompositeMap) = (size(last(A.maps), 1), size(first(A.maps), 2))
26+
Base.axes(A::CompositeMap) = (axes(last(A.maps))[1], axes(first(A.maps))[2])
2727
Base.isreal(A::CompositeMap) = all(isreal, A.maps) # sufficient but not necessary
2828

2929
# the following rules are sufficient but not necessary
@@ -32,17 +32,17 @@ for (f, _f, g) in ((:issymmetric, :_issymmetric, :transpose),
3232
@eval begin
3333
LinearAlgebra.$f(A::CompositeMap) = $_f(A.maps)
3434
$_f(maps::Tuple{}) = true
35-
$_f(maps::Tuple{<:LinearMap}) = $f(maps[1])
35+
$_f(maps::Tuple{<:LinearMap}) = $f(first(maps))
3636
$_f(maps::LinearMapTuple) =
37-
maps[end] == $g(maps[1]) && $_f(Base.front(Base.tail(maps)))
37+
maps[end] == $g(first(maps)) && $_f(Base.front(Base.tail(maps)))
3838
function $_f(maps::LinearMapVector)
3939
n = length(maps)
4040
if n == 0
4141
return true
4242
elseif n == 1
43-
return ($f(maps[1]))::Bool
43+
return ($f(first(maps)))::Bool
4444
else
45-
return ((maps[end] == $g(maps[1]))::Bool && $_f(@views maps[2:end-1]))
45+
return ((last(maps) == $g(first(maps)))::Bool && $_f(@views maps[begin+1:end-1]))
4646
end
4747
end
4848
# since the introduction of ScaledMap, the following cases cannot occur
@@ -71,10 +71,10 @@ function _isposdef(maps::LinearMapVector)
7171
if n == 0
7272
return true
7373
elseif n == 1
74-
return isposdef(maps[1])
74+
return isposdef(first(maps))
7575
else
76-
return (maps[end] == adjoint(maps[1]) || maps[end] == maps[1]) &&
77-
isposdef(maps[1]) && _isposdef(maps[2:end-1])
76+
return (last(maps) == adjoint(first(maps)) || last(maps) == first(maps)) &&
77+
isposdef(first(maps)) && _isposdef(maps[begin+1:end-1])
7878
end
7979
end
8080

@@ -180,12 +180,11 @@ function _unsafe_mul!(y, A::CompositeMap, x::AbstractVector)
180180
return y
181181
end
182182
_unsafe_mul!(y, A::CompositeMap, x::AbstractMatrix) = _compositemul!(y, A, x)
183+
_unsafe_mul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, x::AbstractVector) =
184+
_unsafe_mul!(y, A.maps[1], x)
185+
_unsafe_mul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, X::AbstractMatrix) =
186+
_unsafe_mul!(y, A.maps[1], X)
183187

184-
function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, x,
185-
source = nothing,
186-
dest = nothing)
187-
return _unsafe_mul!(y, A.maps[1], x)
188-
end
189188
function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}}, x,
190189
source = nothing,
191190
dest = nothing)
@@ -206,9 +205,9 @@ function _compositemul!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x,
206205
dest = nothing)
207206
N = length(A.maps)
208207
if N == 1
209-
return _unsafe_mul!(y, A.maps[1], x)
208+
return _unsafe_mul!(y, A.maps[begin], x)
210209
elseif N == 2
211-
return _unsafe_mul!(y, A.maps[2] * A.maps[1], x)
210+
return _unsafe_mul!(y, A.maps[end] * A.maps[begin], x)
212211
else
213212
return _compositemulN!(y, A, x, source, dest)
214213
end
@@ -218,19 +217,21 @@ function _compositemulN!(y, A::CompositeMap, x,
218217
src = nothing,
219218
dst = nothing)
220219
N = length(A.maps) # ≥ 3
220+
n = n0 = firstindex(A.maps)
221221
source = isnothing(src) ?
222-
convert(AbstractArray, A.maps[1] * x) :
223-
_unsafe_mul!(src, A.maps[1], x)
222+
convert(AbstractArray, A.maps[n] * x) :
223+
_unsafe_mul!(src, A.maps[n], x)
224+
n += 1
224225
dest = isnothing(dst) ?
225-
convert(AbstractArray, A.maps[2] * source) :
226-
_unsafe_mul!(dst, A.maps[2], source)
226+
convert(AbstractArray, A.maps[n] * source) :
227+
_unsafe_mul!(dst, A.maps[n], source)
227228
dest, source = source, dest # alternate dest and source
228-
for n in 3:N-1
229+
for n in (n0+2):N-1
229230
dest = _resize(dest, (size(A.maps[n], 1), size(x)[2:end]...))
230231
_unsafe_mul!(dest, A.maps[n], source)
231232
dest, source = source, dest # alternate dest and source
232233
end
233-
_unsafe_mul!(y, A.maps[N], source)
234+
_unsafe_mul!(y, last(A.maps), source)
234235
return y
235236
end
236237

test/composition.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ using LinearMaps: LinearMapVector, LinearMapTuple
66
@test F == LinearMap(cumsum, reverse cumsum reverse, 10; ismutating=false)
77
FC = LinearMap{ComplexF64}(cumsum, reverse cumsum reverse, 10; ismutating=false)
88
FCM = @inferred LinearMaps.CompositeMap{ComplexF64}((FC,))
9+
FCMv = @inferred LinearMaps.CompositeMap{ComplexF64}([FC,])
10+
FCiip = LinearMaps.CompositeMap{ComplexF64}((LinearMap{ComplexF64}(cumsum!, 10),))
11+
FCM2v = @inferred LinearMaps.CompositeMap{ComplexF64}([FC, FC])
12+
FCM2iipv = @inferred LinearMaps.CompositeMap{ComplexF64}([FCiip, FCiip])
913
L = LowerTriangular(ones(10,10))
1014
@test_throws DimensionMismatch F * LinearMap(zeros(2,2))
1115
@test_throws ErrorException LinearMaps.CompositeMap{Float64}((FC, LinearMap(rand(10,10))))
@@ -17,7 +21,8 @@ using LinearMaps: LinearMapVector, LinearMapTuple
1721
N = @inferred LinearMap(B)
1822
v = rand(ComplexF64, 10)
1923
α = rand(ComplexF64)
20-
@test FCM * v == F * v
24+
@test FCiip * v == FCM * v == F * v == FCMv * v
25+
@test FCM2v * v == F * F * v == FCM2iipv * v
2126
@test @inferred (F * F) * v == @inferred F * (F * v)
2227
@test @inferred (F * A) * v == @inferred F * (A * v)
2328
@test LinearMaps._compositemul!(zero(F * A * v), F * A, v, zero(A*v)) (F * A) * v
@@ -123,6 +128,7 @@ using LinearMaps: LinearMapVector, LinearMapTuple
123128
w1 = im.*ones(ComplexF64, prod(sizes[1]))
124129
for i = N:-1:1
125130
v2 = prod(Lf[i:N]) * ones(prod(sizes[1]))
131+
i < N && (y2 = LinearMaps._compositemul!(zero(v2), prod(Lf[i:N]), ones(prod(sizes[1]))))
126132
u2 = transpose(LinearMap(prod(Lt[N:-1:i]))) * ones(prod(sizes[1]))
127133
w2 = adjoint(LinearMap(prod(Lc[N:-1:i]))) * ones(prod(sizes[1]))
128134

@@ -131,6 +137,7 @@ using LinearMaps: LinearMapVector, LinearMapTuple
131137
w1 = adjoint(Lc[i]) * w1
132138

133139
@test v1 == v2
140+
i < N && @test v2 == y2
134141
@test u1 == u2
135142
@test w1 == w2
136143
end

0 commit comments

Comments
 (0)