Skip to content

Simplify multiple dispatch #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions docs/src/custom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ Base.size(A::MyFillMap) = A.size
# on the level of `mul!` etc. Factoring out dimension checking is done to minimise overhead
# caused by repetitive checking.

function LinearMaps._unsafe_mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
# !!! note
# Multiple dispatch at the `_unsafe_mul!` level happens via the second (the map type)
# and the third arguments (`AbstractVector` or `AbstractMatrix`, see the
# [Application to matrices](@ref) section below). For that reason, the output argument
# can remain type-unbound.

function LinearMaps._unsafe_mul!(y, A::MyFillMap, x::AbstractVector)
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
end

Expand All @@ -45,7 +51,8 @@ end
# * in-place multiplication with vectors `mul!(y, A, x)`,
# * in-place multiply-and-add with vectors `mul!(y, A, x, α, β)`,
# * in-place multiplication and multiply-and-add with matrices `mul!(Y, A, X, α, β)`,
# * conversion to a (sparse) matrix `Matrix(A)` and `sparse(A)`.
# * conversion to a (sparse) matrix `Matrix(A)` and `sparse(A)`,
# * complete slicing of columns (and rows if the adjoint action is defined).

A = MyFillMap(5.0, (3, 3)); x = ones(3); sum(x)

Expand Down Expand Up @@ -80,20 +87,14 @@ using BenchmarkTools

# The second benchmark indicates the allocation of an intermediate vector `z`
# which stores the result of `A*x` before it gets scaled and added to (the scaled)
# `y = zeros(3)`. For that reason, it is beneficial to provide a custom "5-arg `mul!`"
# if you can avoid the allocation of an intermediate vector. To indicate that there
# exists an allocation-free implementation, you should set the `MulStyle` trait,
# whose default is `ThreeArg()`.
# `y = zeros(3)`. For that reason, it is beneficial to provide a custom "5-arg
# `_unsafe_mul!`" if you can avoid the allocation of an intermediate vector. To indicate
# that there exists an allocation-free implementation of multiply-and-add, you should set
# the `MulStyle` trait, whose default is `ThreeArg()`, to `FiveArg()`.

LinearMaps.MulStyle(A::MyFillMap) = FiveArg()

function LinearMaps._unsafe_mul!(
y::AbstractVecOrMat,
A::MyFillMap,
x::AbstractVector,
α::Number,
β::Number
)
function LinearMaps._unsafe_mul!(y, A::MyFillMap, x::AbstractVector, α, β)
if iszero(α)
!isone(β) && rmul!(y, β)
return y
Expand Down Expand Up @@ -159,7 +160,7 @@ try MyFillMap(5.0, (3, 4))' * ones(3) catch e println(e) end
# wrapped map types; for instance,

function LinearMaps._unsafe_mul!(
y::AbstractVecOrMat,
y,
transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap},
x::AbstractVector
)
Expand All @@ -183,7 +184,7 @@ MyFillMap(5.0, (3, 4))' * ones(3)
Base.delete_method(
first(methods(
LinearMaps._unsafe_mul!,
(AbstractVecOrMat, LinearMaps.TransposeMap{<:Any,<:MyFillMap}, AbstractVector))
(Any, LinearMaps.TransposeMap{<:Any,<:MyFillMap}, AbstractVector))
)
)

Expand Down Expand Up @@ -222,13 +223,13 @@ mul!(similar(x)', x', A)
# Calling the in-place multiplication function `mul!(Y, A, X)` for matrices,
# however, does compute the columnwise action of `A` on `X` and stores the
# result in `Y`. In case there is a more efficient implementation for the
# matrix application, you can provide `mul!` methods with signature
# `mul!(Y::AbstractMatrix, A::MyFillMap, X::AbstractMatrix)`, and, depending
# matrix application, you can provide `_unsafe_mul!` methods with signature
# `_unsafe_mul!(Y, A::MyFillMap, X::AbstractMatrix)`, and, depending
# on the chosen path to handle adjoints/transposes, corresponding methods
# for wrapped maps of type `AdjointMap` or `TransposeMap`, plus potentially
# corresponding 5-arg `mul!` methods. This may seem like a lot of methods to
# be implemented, but note that adding such methods is only necessary/recommended
# for performance.
# for increased performance.

# ## Computing a matrix representation

Expand All @@ -250,7 +251,7 @@ M = Matrix{eltype(F)}(undef, size(F))
# for instance (as before, size checks need not be included here since they are handled by
# the corresponding `LinearAlgebra.mul!` method):

LinearMaps._unsafe_mul!(M::AbstractMatrix, A::MyFillMap, s::Number) = fill!(M, A.λ*s)
LinearMaps._unsafe_mul!(M, A::MyFillMap, s::Number) = fill!(M, A.λ*s)
@benchmark Matrix($F)

#-
Expand Down
2 changes: 1 addition & 1 deletion docs/src/history.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
For custom `LinearMap` subtypes, there are now two options:
1. In case your type is invariant under adjoint/transposition (i.e.,
`adjoint(L::MyLinearMap)::MyLinearMap` similar to, for instance,
`LinearCombination`s or `CompositeMap`s, `At_mul_B!` and `Ac_mul_B!` do
`LinearCombination`s or `CompositeMap`s), `At_mul_B!` and `Ac_mul_B!` do
not require any replacement! Rather, multiplication by `L'` is, in this case,
handled by `mul!(y, L::MyLinearMap, x[, α, β])`.
2. Otherwise, you will need to define `mul!` methods with the signature
Expand Down
1 change: 1 addition & 0 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ Base.:*(::AbstractMatrix,::LinearMap)
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector)
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector,::Number,::Number)
LinearAlgebra.mul!(::AbstractMatrix,::AbstractMatrix,::LinearMap)
LinearAlgebra.mul!(::AbstractMatrix,::AbstractMatrix,::LinearMap,::Number,::Number)
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::Number)
LinearAlgebra.mul!(::AbstractMatrix,::LinearMap,::Number,::Number,::Number)
*(::LinearAlgebra.AdjointAbsVec,::LinearMap)
Expand Down
23 changes: 11 additions & 12 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ abstract type LinearMap{T} end

const MapOrVecOrMat{T} = Union{LinearMap{T}, AbstractVecOrMat{T}}
const MapOrMatrix{T} = Union{LinearMap{T}, AbstractMatrix{T}}
const TransposeAbsVecOrMat{T} = Transpose{T,<:AbstractVecOrMat}
const RealOrComplex = Union{Real, Complex}

const LinearMapTuple = Tuple{Vararg{LinearMap}}
Expand Down Expand Up @@ -78,9 +79,9 @@ function check_dim_mul(C, A, B)
end

_front(As::Tuple) = Base.front(As)
_front(As::AbstractVector) = @inbounds @views As[1:end-1]
_front(As::AbstractVector) = @inbounds @views As[begin:end-1]
_tail(As::Tuple) = Base.tail(As)
_tail(As::AbstractVector) = @inbounds @views As[2:end]
_tail(As::AbstractVector) = @inbounds @views As[begin+1:end]

_combine(A::LinearMap, B::LinearMap) = tuple(A, B)
_combine(A::LinearMap, Bs::LinearMapTuple) = tuple(A, Bs...)
Expand Down Expand Up @@ -258,15 +259,13 @@ end

_unsafe_mul!(y, A::MapOrVecOrMat, x) = mul!(y, A, x)
_unsafe_mul!(y, A::AbstractVecOrMat, x, α, β) = mul!(y, A, x, α, β)
_unsafe_mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector, α, β) =
_generic_map_mul!(y, A, x, α, β)
_unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix) =
_generic_map_mul!(y, A, x)
_unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix, α::Number, β::Number) =
_generic_map_mul!(y, A, x, α, β)
_unsafe_mul!(Y::AbstractMatrix, A::LinearMap, s::Number) = _generic_map_mul!(Y, A, s)
_unsafe_mul!(Y::AbstractMatrix, A::LinearMap, s::Number, α::Number, β::Number) =
_generic_map_mul!(Y, A, s, α, β)
_unsafe_mul!(X, Y::AbstractMatrix, A::AbstractVecOrMat) = mul!(X, Y, A)
_unsafe_mul!(X, Y::AbstractMatrix, A::AbstractVecOrMat, α, β) = mul!(X, Y, A, α, β)
_unsafe_mul!(y, A::LinearMap, x::AbstractVector, α, β) = _generic_map_mul!(y, A, x, α, β)
_unsafe_mul!(y, A::LinearMap, x::AbstractMatrix) = _generic_map_mul!(y, A, x)
_unsafe_mul!(y, A::LinearMap, x::AbstractMatrix, α, β) = _generic_map_mul!(y, A, x, α, β)
_unsafe_mul!(Y, A::LinearMap, s::Number) = _generic_map_mul!(Y, A, s)
_unsafe_mul!(Y, A::LinearMap, s::Number, α, β) = _generic_map_mul!(Y, A, s, α, β)

function _generic_map_mul!(y, A, x::AbstractVector, α, β)
# this function needs to call mul! for, e.g., AdjointMap{...,<:CustomMap}
Expand Down Expand Up @@ -330,9 +329,9 @@ function _generic_map_mul!(Y, A, s::Number, α, β)
return Y
end

include("left.jl") # left multiplication by a transpose or adjoint vector
include("transpose.jl") # transposing linear maps
include("wrappedmap.jl") # wrap a matrix of linear map in a new type, thereby allowing to alter its properties
include("left.jl") # left multiplication by a transpose or adjoint vector
include("uniformscalingmap.jl") # the uniform scaling map, to be able to make linear combinations of LinearMap objects and multiples of I
include("linearcombination.jl") # defining linear combinations of linear maps
include("scaledmap.jl") # multiply by a (real or complex) scalar
Expand Down
21 changes: 10 additions & 11 deletions src/blockmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,13 +427,13 @@ end
# multiplication with vectors & matrices
############

for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
for In in (AbstractVector, AbstractMatrix)
@eval begin
function _unsafe_mul!(y::$Out, A::BlockMap, x::$In)
function _unsafe_mul!(y, A::BlockMap, x::$In)
require_one_based_indexing(y, x)
return _blockmul!(y, A, x, true, false)
end
function _unsafe_mul!(y::$Out, A::BlockMap, x::$In, α::Number, β::Number)
function _unsafe_mul!(y, A::BlockMap, x::$In, α, β)
require_one_based_indexing(y, x)
return _blockmul!(y, A, x, α, β)
end
Expand All @@ -442,11 +442,11 @@ for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractM
for (MT, transform) in ((:TransposeMap, :transpose), (:AdjointMap, :adjoint))
@eval begin
MapType = $MT{<:Any, <:BlockMap}
function _unsafe_mul!(y::$Out, wrapA::MapType, x::$In)
function _unsafe_mul!(y, wrapA::MapType, x::$In)
require_one_based_indexing(y, x)
return _transblockmul!(y, wrapA.lmap, x, true, false, $transform)
end
function _unsafe_mul!(y::$Out, wrapA::MapType, x::$In, α::Number, β::Number)
function _unsafe_mul!(y, wrapA::MapType, x::$In, α, β)
require_one_based_indexing(y, x)
return _transblockmul!(y, wrapA.lmap, x, α, β, $transform)
end
Expand All @@ -458,14 +458,13 @@ end
# multiplication with a scalar
############

function _unsafe_mul!(Y::AbstractMatrix, A::BlockMap, s::Number, α::Number=true, β::Number=false)
function _unsafe_mul!(Y, A::BlockMap, s::Number, α=true, β=false)
require_one_based_indexing(Y, s)
return _blockmul!(Y, A, s, α, β)
end
for (MT, transform) in ((:TransposeMap, :transpose), (:AdjointMap, :adjoint))
@eval begin
function _unsafe_mul!(Y::AbstractMatrix, wrapA::$MT{<:Any, <:BlockMap}, s::Number,
α::Number=true, β::Number=false)
function _unsafe_mul!(Y, wrapA::$MT{<:Any, <:BlockMap}, s::Number, α=true, β=false)
require_one_based_indexing(Y)
return _transblockmul!(Y, wrapA.lmap, s, α, β, $transform)
end
Expand Down Expand Up @@ -557,13 +556,13 @@ LinearAlgebra.transpose(A::BlockDiagonalMap{T}) where {T} =
Base.:(==)(A::BlockDiagonalMap, B::BlockDiagonalMap) =
(eltype(A) == eltype(B) && all(A.maps .== B.maps))

for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix), (Number, AbstractMatrix))
for In in (AbstractVector, AbstractMatrix, Number)
@eval begin
function _unsafe_mul!(y::$Out, A::BlockDiagonalMap, x::$In)
function _unsafe_mul!(y, A::BlockDiagonalMap, x::$In)
require_one_based_indexing(y, x)
return _blockscaling!(y, A, x)
end
function _unsafe_mul!(y::$Out, A::BlockDiagonalMap, x::$In, α::Number, β::Number)
function _unsafe_mul!(y, A::BlockDiagonalMap, x::$In, α, β)
require_one_based_indexing(y, x)
return _blockscaling!(y, A, x, α, β)
end
Expand Down
30 changes: 8 additions & 22 deletions src/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,15 @@ Base.:(==)(A::CompositeMap, B::CompositeMap) =
(eltype(A) == eltype(B) && all(A.maps .== B.maps))

# multiplication with vectors/matrices
_unsafe_mul!(y::AbstractVecOrMat, A::CompositeMap, x::AbstractVector) =
_compositemul!(y, A, x)
_unsafe_mul!(y::AbstractMatrix, A::CompositeMap, x::AbstractMatrix) =
_compositemul!(y, A, x)
_unsafe_mul!(y, A::CompositeMap, x::AbstractVector) = _compositemul!(y, A, x)
_unsafe_mul!(y, A::CompositeMap, x::AbstractMatrix) = _compositemul!(y, A, x)

function _compositemul!(y::AbstractVecOrMat,
A::CompositeMap{<:Any,<:Tuple{LinearMap}},
x::AbstractVecOrMat,
function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, x,
source = nothing,
dest = nothing)
return _unsafe_mul!(y, A.maps[1], x)
end
function _compositemul!(y::AbstractVecOrMat,
A::CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}},
x::AbstractVecOrMat,
function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
dest = nothing)
_unsafe_mul!(source, A.maps[1], x)
Expand All @@ -204,9 +198,7 @@ function _resize(dest::AbstractMatrix, sz::Tuple{<:Integer,<:Integer})
similar(dest, sz)
end

function _compositemul!(y::AbstractVecOrMat,
A::CompositeMap{<:Any,<:LinearMapTuple},
x::AbstractVecOrMat,
function _compositemul!(y, A::CompositeMap{<:Any,<:LinearMapTuple}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
dest = similar(y, (size(A.maps[2],1), size(x)[2:end]...)))
N = length(A.maps)
Expand All @@ -220,9 +212,7 @@ function _compositemul!(y::AbstractVecOrMat,
return y
end

function _compositemul!(y::AbstractVecOrMat,
A::CompositeMap{<:Any,<:LinearMapVector},
x::AbstractVecOrMat)
function _compositemul!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x)
N = length(A.maps)
if N == 1
return _unsafe_mul!(y, A.maps[1], x)
Expand All @@ -233,17 +223,13 @@ function _compositemul!(y::AbstractVecOrMat,
end
end

function _compositemul2!(y::AbstractVecOrMat,
A::CompositeMap{<:Any,<:LinearMapVector},
x::AbstractVecOrMat,
function _compositemul2!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)))
_unsafe_mul!(source, A.maps[1], x)
_unsafe_mul!(y, A.maps[2], source)
return y
end
function _compositemulN!(y::AbstractVecOrMat,
A::CompositeMap{<:Any,<:LinearMapVector},
x::AbstractVecOrMat,
function _compositemulN!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
dest = similar(y, (size(A.maps[2],1), size(x)[2:end]...)))
N = length(A.maps)
Expand Down
14 changes: 7 additions & 7 deletions src/embeddedmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,25 @@ Base.:(==)(A::EmbeddedMap, B::EmbeddedMap) =
LinearAlgebra.adjoint(A::EmbeddedMap) = EmbeddedMap(adjoint(A.lmap), reverse(A.dims), A.cols, A.rows)
LinearAlgebra.transpose(A::EmbeddedMap) = EmbeddedMap(transpose(A.lmap), reverse(A.dims), A.cols, A.rows)

for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
@eval function _unsafe_mul!(y::$Out, A::EmbeddedMap, x::$In)
for In in (AbstractVector, AbstractMatrix)
@eval function _unsafe_mul!(y, A::EmbeddedMap, x::$In)
fill!(y, zero(eltype(y)))
_unsafe_mul!(selectdim(y, 1, A.rows), A.lmap, selectdim(x, 1, A.cols))
return y
end
@eval function _unsafe_mul!(y::$Out, A::EmbeddedMap, x::$In, alpha::Number, beta::Number)
LinearAlgebra._rmul_or_fill!(y, beta)
_unsafe_mul!(selectdim(y, 1, A.rows), A.lmap, selectdim(x, 1, A.cols), alpha, !iszero(beta))
@eval function _unsafe_mul!(y, A::EmbeddedMap, x::$In, α, β)
LinearAlgebra._rmul_or_fill!(y, β)
_unsafe_mul!(selectdim(y, 1, A.rows), A.lmap, selectdim(x, 1, A.cols), α, !iszero(β))
return y
end
end

function _unsafe_mul!(Y::AbstractMatrix, A::EmbeddedMap, x::Number)
function _unsafe_mul!(Y, A::EmbeddedMap, x::Number)
fill!(Y, zero(eltype(Y)))
_unsafe_mul!(view(Y, A.rows, A.cols), A.lmap, x)
return Y
end
function _unsafe_mul!(Y::AbstractMatrix, A::EmbeddedMap, x::Number, α::Number, β::Number)
function _unsafe_mul!(Y, A::EmbeddedMap, x::Number, α, β)
LinearAlgebra._rmul_or_fill!(Y, β)
_unsafe_mul!(view(Y, A.rows, A.cols), A.lmap, x, α, !iszero(β))
return Y
Expand Down
8 changes: 4 additions & 4 deletions src/fillmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ function Base.:(*)(A::FillMap, x::AbstractVector)
return fill(iszero(A.λ) ? zero(T) : A.λ*sum(x), A.size[1])
end

function _unsafe_mul!(y::AbstractVecOrMat, A::FillMap, x::AbstractVector)
function _unsafe_mul!(y, A::FillMap, x::AbstractVector)
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
end

_unsafe_mul!(Y::AbstractMatrix, A::FillMap, x::Number) = fill!(Y, A.λ*x)
function _unsafe_mul!(Y::AbstractMatrix, A::FillMap, x::Number, α::Number, β::Number)
_unsafe_mul!(Y, A::FillMap, x::Number) = fill!(Y, A.λ*x)
function _unsafe_mul!(Y, A::FillMap, x::Number, α, β)
LinearAlgebra._rmul_or_fill!(Y, β)
Y .+= A.λ*x*α
return Y
end

function _unsafe_mul!(y::AbstractVecOrMat, A::FillMap, x::AbstractVector, α::Number, β::Number)
function _unsafe_mul!(y, A::FillMap, x::AbstractVector, α, β)
if iszero(α)
!isone(β) && rmul!(y, β)
else
Expand Down
6 changes: 3 additions & 3 deletions src/functionmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ function Base.:(*)(A::TransposeFunctionMap, x::AbstractVector)
end
end

function _unsafe_mul!(y::AbstractVecOrMat, A::FunctionMap, x::AbstractVector)
function _unsafe_mul!(y, A::FunctionMap, x::AbstractVector)
ismutating(A) ? A.f(y, x) : copyto!(y, A.f(x))
return y
end

function _unsafe_mul!(y::AbstractVecOrMat, At::TransposeFunctionMap, x::AbstractVector)
function _unsafe_mul!(y, At::TransposeFunctionMap, x::AbstractVector)
A = At.lmap
(issymmetric(A) || (isreal(A) && ishermitian(A))) && return _unsafe_mul!(y, A, x)
if A.fc !== nothing
Expand All @@ -136,7 +136,7 @@ function _unsafe_mul!(y::AbstractVecOrMat, At::TransposeFunctionMap, x::Abstract
end
end

function _unsafe_mul!(y::AbstractVecOrMat, Ac::AdjointFunctionMap, x::AbstractVector)
function _unsafe_mul!(y, Ac::AdjointFunctionMap, x::AbstractVector)
A = Ac.lmap
ishermitian(A) && return _unsafe_mul!(y, A, x)
if A.fc !== nothing
Expand Down
4 changes: 2 additions & 2 deletions src/inversemap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ LinearAlgebra.ishermitian(imap::InverseMap) = ishermitian(imap.A)
LinearAlgebra.isposdef(imap::InverseMap) = isposdef(imap.A)

# Two separate methods to deal with method ambiguities
function _unsafe_mul!(y::AbstractVector, imap::InverseMap, x::AbstractVector)
function _unsafe_mul!(y, imap::InverseMap, x::AbstractVector)
imap.ldiv!(y, imap.A, x)
return y
end
function _unsafe_mul!(y::AbstractMatrix, imap::InverseMap, x::AbstractMatrix)
function _unsafe_mul!(y, imap::InverseMap, x::AbstractMatrix)
imap.ldiv!(y, imap.A, x)
return y
end
Loading