Skip to content

Commit d5111de

Browse files
committed
use multiple dispatch in generic multiplication
1 parent 33395aa commit d5111de

File tree

3 files changed

+34
-53
lines changed

3 files changed

+34
-53
lines changed

src/LinearMaps.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -258,16 +258,16 @@ end
258258
_unsafe_mul!(y, A::MapOrVecOrMat, x) = mul!(y, A, x)
259259
_unsafe_mul!(y, A::AbstractVecOrMat, x, α, β) = mul!(y, A, x, α, β)
260260
_unsafe_mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector, α, β) =
261-
_generic_mapvec_mul!(y, A, x, α, β)
261+
_generic_map_mul!(y, A, x, α, β)
262262
_unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix) =
263-
_generic_mapmat_mul!(y, A, x)
263+
_generic_map_mul!(y, A, x)
264264
_unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix, α::Number, β::Number) =
265-
_generic_mapmat_mul!(y, A, x, α, β)
266-
_unsafe_mul!(Y::AbstractMatrix, A::LinearMap, s::Number) = _generic_mapnum_mul!(Y, A, s)
265+
_generic_map_mul!(y, A, x, α, β)
266+
_unsafe_mul!(Y::AbstractMatrix, A::LinearMap, s::Number) = _generic_map_mul!(Y, A, s)
267267
_unsafe_mul!(Y::AbstractMatrix, A::LinearMap, s::Number, α::Number, β::Number) =
268-
_generic_mapnum_mul!(Y, A, s, α, β)
268+
_generic_map_mul!(Y, A, s, α, β)
269269

270-
function _generic_mapvec_mul!(y, A, x, α, β)
270+
function _generic_map_mul!(y, A, x::AbstractVector, α, β)
271271
# this function needs to call mul! for, e.g., AdjointMap{...,<:CustomMap}
272272
if isone(α)
273273
iszero(β) && return mul!(y, A, x)
@@ -294,21 +294,19 @@ function _generic_mapvec_mul!(y, A, x, α, β)
294294
return y
295295
end
296296
end
297-
298-
function _generic_mapmat_mul!(Y, A, X)
297+
function _generic_map_mul!(Y, A, X::AbstractMatrix)
299298
for (Xi, Yi) in zip(eachcol(X), eachcol(Y))
300299
mul!(Yi, A, Xi)
301300
end
302301
return Y
303302
end
304-
function _generic_mapmat_mul!(Y, A, X, α, β)
303+
function _generic_map_mul!(Y, A, X::AbstractMatrix, α, β)
305304
for (Xi, Yi) in zip(eachcol(X), eachcol(Y))
306305
mul!(Yi, A, Xi, α, β)
307306
end
308307
return Y
309308
end
310-
311-
function _generic_mapnum_mul!(Y, A, s)
309+
function _generic_map_mul!(Y, A, s::Number)
312310
T = promote_type(eltype(A), typeof(s))
313311
ax2 = axes(A)[2]
314312
xi = zeros(T, ax2)
@@ -319,7 +317,7 @@ function _generic_mapnum_mul!(Y, A, s)
319317
end
320318
return Y
321319
end
322-
function _generic_mapnum_mul!(Y, A, s, α, β)
320+
function _generic_map_mul!(Y, A, s::Number, α, β)
323321
T = promote_type(eltype(A), typeof(s))
324322
ax2 = axes(A)[2]
325323
xi = zeros(T, ax2)

src/transpose.jl

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -50,44 +50,24 @@ Base.:(==)(A::LinearMap, B::TransposeMap) = issymmetric(A) && B.lmap == A
5050
Base.:(==)(A::LinearMap, B::AdjointMap) = ishermitian(A) && B.lmap == A
5151

5252
# multiplication with vector/matrices
53-
# # TransposeMap
54-
_unsafe_mul!(y::AbstractVecOrMat, A::TransposeMap, x::AbstractVector) =
55-
issymmetric(A.lmap) ?
56-
_unsafe_mul!(y, A.lmap, x) : error("transpose not implemented for $(A.lmap)")
57-
_unsafe_mul!(y::AbstractMatrix, A::TransposeMap, x::AbstractMatrix) =
58-
issymmetric(A.lmap) ?
59-
_unsafe_mul!(y, A.lmap, x) : _generic_mapmat_mul!(y, A, x)
60-
_unsafe_mul!(y::AbstractMatrix, A::TransposeMap, x::Number) =
61-
issymmetric(A.lmap) ?
62-
_unsafe_mul!(y, A.lmap, x) : _generic_mapnum_mul!(y, A, x)
63-
_unsafe_mul!(y::AbstractVecOrMat, A::TransposeMap, x::AbstractVector, α::Number, β::Number)=
64-
issymmetric(A.lmap) ?
65-
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_mapvec_mul!(y, A, x, α, β)
66-
_unsafe_mul!(y::AbstractMatrix, A::TransposeMap, x::AbstractMatrix, α::Number, β::Number) =
67-
issymmetric(A.lmap) ?
68-
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_mapmat_mul!(y, A, x, α, β)
69-
_unsafe_mul!(y::AbstractMatrix, A::TransposeMap, x::Number, α::Number, β::Number) =
70-
issymmetric(A.lmap) ?
71-
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_mapnum_mul!(y, A, x, α, β)
72-
# # AdjointMap
73-
_unsafe_mul!(y::AbstractVecOrMat, A::AdjointMap, x::AbstractVector) =
74-
ishermitian(A.lmap) ?
75-
_unsafe_mul!(y, A.lmap, x) : error("adjoint not implemented for $(A.lmap)")
76-
_unsafe_mul!(y::AbstractMatrix, A::AdjointMap, x::AbstractMatrix) =
77-
ishermitian(A.lmap) ?
78-
_unsafe_mul!(y, A.lmap, x) : _generic_mapmat_mul!(y, A, x)
79-
_unsafe_mul!(y::AbstractMatrix, A::AdjointMap, x::Number) =
80-
ishermitian(A.lmap) ?
81-
_unsafe_mul!(y, A.lmap, x) : _generic_mapnum_mul!(y, A, x)
82-
_unsafe_mul!(y::AbstractVecOrMat, A::AdjointMap, x::AbstractVector, α::Number, β::Number) =
83-
ishermitian(A.lmap) ?
84-
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_mapvec_mul!(y, A, x, α, β)
85-
_unsafe_mul!(y::AbstractMatrix, A::AdjointMap, x::AbstractMatrix, α::Number, β::Number) =
86-
ishermitian(A.lmap) ?
87-
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_mapmat_mul!(y, A, x, α, β)
88-
_unsafe_mul!(y::AbstractMatrix, A::AdjointMap, x::Number, α::Number, β::Number) =
89-
ishermitian(A.lmap) ?
90-
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_mapnum_mul!(y, A, x, α, β)
53+
for (Typ, prop, text) in ((AdjointMap, ishermitian, "adjoint"), (TransposeMap, issymmetric, "transpose"))
54+
@eval _unsafe_mul!(y::AbstractVecOrMat, A::$Typ, x::AbstractVector) =
55+
$prop(A.lmap) ?
56+
_unsafe_mul!(y, A.lmap, x) : error($text * " not implemented for $(A.lmap)")
57+
@eval _unsafe_mul!(y::AbstractVecOrMat, A::$Typ, x::AbstractVector, α::Number, β::Number) =
58+
$prop(A.lmap) ?
59+
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_map_mul!(y, A, x, α, β)
60+
61+
for In in (Number, AbstractMatrix)
62+
@eval _unsafe_mul!(y::AbstractMatrix, A::$Typ, x::$In) =
63+
$prop(A.lmap) ?
64+
_unsafe_mul!(y, A.lmap, x) : _generic_map_mul!(y, A, x)
65+
66+
@eval _unsafe_mul!(y::AbstractMatrix, A::$Typ, x::$In, α::Number, β::Number) =
67+
ishermitian(A.lmap) ?
68+
_unsafe_mul!(y, A.lmap, x, α, β) : _generic_map_mul!(y, A, x, α, β)
69+
end
70+
end
9171

9272
# # ConjugateMap
9373
const ConjugateMap = AdjointMap{<:Any, <:TransposeMap}
@@ -104,9 +84,12 @@ for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractM
10484
end
10585
end
10686
end
87+
function _unsafe_mul!(y::AbstractMatrix, Ac::ConjugateMap, x::Number)
88+
return _conjmul!(y, Ac.lmap.lmap, x)
89+
end
10790

10891
# multiplication helper function
109-
_conjmul!(y, A, x) = conj!(mul!(y, A, conj(x)))
92+
_conjmul!(y, A, x) = conj!(_unsafe_mul!(y, A, conj(x)))
11093
function _conjmul!(y, A, x::AbstractVector, α, β)
11194
xca = conj!(x * α)
11295
z = A * xca

test/linearmaps.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ end
6363
w = rand(ComplexF64, 10); W = rand(ComplexF64, 10, 3)
6464
F(v) == F*v
6565
@test mul!(w, F, v) === w == F * v
66-
@test_throws ErrorException F' * v
66+
@test_throws ErrorException("transpose not implemented for "*sprint((t, s) -> show(t, "text/plain", s), F)) F' * v
6767
@test_throws ErrorException transpose(F) * v
68-
@test_throws ErrorException mul!(w, adjoint(FC), v)
68+
@test_throws ErrorException("adjoint not implemented for "*sprint((t, s) -> show(t, "text/plain", s), FC)) mul!(w, adjoint(FC), v)
6969
@test_throws ErrorException mul!(w, transpose(F), v)
7070
FM = convert(AbstractMatrix, F)
7171
L = LowerTriangular(ones(10, 10))

0 commit comments

Comments
 (0)