Skip to content

Make ismutable a type-parameter of FunctionMap #194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LinearMaps"
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
version = "3.9.0"
version = "3.10.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
15 changes: 15 additions & 0 deletions docs/src/history.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# Version history

## What's new in v3.10

* A new `MulStyle` trait called `TwoArg` has been added. It should be used for `LinearMap`s
that do not admit a mutating multiplication à la (3-arg or 5-arg) `mul!`, but only
out-of-place multiplication à la `A * x`. Products (aka `CompositeMap`s) and sums (aka
`LinearCombination`s) of `TwoArg`-`LinearMap`s now have memory-optimized multiplication
kernels. For instance, `A*B*C*x` for three `TwoArg`-`LinearMap`s `A`, `B` and `C` now
allocates only `y = C*x`, `z = B*y` and the result of `A*z`.
* The construction of function-based `LinearMap`s, typed `FunctionMap`, has been rearranged.
Additionally to the convenience constructor `LinearMap{T=Float64}(f, [fc,] M, N=M; kwargs...)`,
the newly exported constructor `FunctionMap{T,iip}(f, [fc], M, N; kwargs...)` is readily
available. Here, `iip` is either `true` or `false`, and encodes whether `f` (and `fc` if
present) are mutating functions. In the convenience constructor, this is determined via the
`Bool` keyword argument `ismutating` and may not be fully inferred.

## What's new in v3.9

* The application of `LinearMap`s to vectors operation, i.e., `(A,x) -> A*x = A(x)`, is now
Expand Down
11 changes: 2 additions & 9 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
pkg> add LinearMaps
```

in package mode, to be entered by typing `]` in the Julia REPL.
in package mode, which can be entered by typing `]` in the Julia REPL.

## Examples

Expand Down Expand Up @@ -100,16 +100,10 @@ KrylovKit.eigsolve(-A, size(A, 1), 3, :SR)

Arpack.eigs(Δ; nev=3, which=:LR)
ArnoldiMethod.partialeigen(ArnoldiMethod.partialschur(Δ; nev=3, which=ArnoldiMethod.LR())[1])
KrylovKit.eigsolve(x -> Δ*x, size(Δ, 1), 3, :LR)
```

In Julia v1.3 and above, the last line can be simplified to

```julia
KrylovKit.eigsolve(Δ, size(Δ, 1), 3, :LR)
```

leveraging the fact that objects of type `L <: LinearMap` are callable.
In the last line above we leverage the fact that objects of type `L <: LinearMap` are callable.

### Inverse map with conjugate gradient

Expand Down Expand Up @@ -156,7 +150,6 @@ result = C * tmp2
i.e. inside the CG solver for solving `Sx = b` we use CG to solve another inner linear
system.


## Philosophy

Several iterative linear algebra methods such as linear solvers or eigensolvers
Expand Down
11 changes: 7 additions & 4 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@ constructor described next.
Abstract supertype

```@docs
LinearMaps.LinearMap
LinearMap
```

### `FunctionMap`

Type for wrapping an arbitrary function that is supposed to implement the
matrix-vector product as a `LinearMap`; see above.

```@docs
FunctionMap
```

### `WrappedMap`

Type for wrapping an `AbstractMatrix` or `LinearMap` and to possible redefine
Expand Down Expand Up @@ -99,20 +103,19 @@ SparseArrays.blockdiag
Type for lazily representing constantly filled matrices.

```@docs
LinearMaps.FillMap
FillMap
```

### `EmbeddedMap`

Type for representing linear maps that are embedded in larger zero maps.


### `InverseMap`

Type for lazy inverse of another linear map.

```@docs
LinearMaps.InverseMap
InverseMap
```

### `KhatriRaoMap` and `FaceSplittingMap`
Expand Down
28 changes: 15 additions & 13 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module LinearMaps

export LinearMap, FillMap, InverseMap
export LinearMap, FunctionMap, FillMap, InverseMap
export ⊗, squarekron, kronsum, ⊕, sumkronsum, khatrirao, facesplitting

using LinearAlgebra
Expand Down Expand Up @@ -40,13 +40,19 @@ convert_to_lmaps(A) = (convert(LinearMap, A),)

abstract type MulStyle end

struct FiveArg <: MulStyle end
struct ThreeArg <: MulStyle end
struct FiveArg <: MulStyle end # types admit in-place multiplication and addition
struct ThreeArg <: MulStyle end # types "only" admit in-place multiplication
struct TwoArg <: MulStyle end # types "only" admit out-of-place multiplication

MulStyle(::FiveArg, ::FiveArg) = FiveArg()
MulStyle(::ThreeArg, ::FiveArg) = ThreeArg()
MulStyle(::FiveArg, ::ThreeArg) = ThreeArg()
MulStyle(::FiveArg, ::TwoArg) = TwoArg()
MulStyle(::ThreeArg, ::FiveArg) = ThreeArg()
MulStyle(::ThreeArg, ::ThreeArg) = ThreeArg()
MulStyle(::ThreeArg, ::TwoArg) = ThreeArg()
MulStyle(::TwoArg, ::FiveArg) = TwoArg()
MulStyle(::TwoArg, ::ThreeArg) = ThreeArg()
MulStyle(::TwoArg, ::TwoArg) = TwoArg()
MulStyle(::LinearMap) = ThreeArg() # default
MulStyle(::AbstractVecOrMat) = FiveArg()
MulStyle(::AbstractQ) = ThreeArg()
Expand Down Expand Up @@ -113,10 +119,6 @@ _combine(As::LinearMapVector, Bs::LinearMapVector) = Base.vect(As..., Bs...)

Compute the action of the linear map `A` on the vector `x`.

!!! compat "Julia 1.3"
In Julia versions v1.3 and above, objects `L` of any subtype of `LinearMap`
are callable in the sense that `L(x) = L*x` for `x::AbstractVector`.

## Examples
```jldoctest; setup=(using LinearAlgebra, LinearMaps)
julia> A=LinearMap([1.0 2.0; 3.0 4.0]); x=[1.0, 1.0];
Expand All @@ -136,7 +138,7 @@ function Base.:(*)(A::LinearMap, x::AbstractVector)
check_dim_mul(A, x)
T = promote_type(eltype(A), eltype(x))
y = similar(x, T, axes(A)[1])
return mul!(y, A, x)
return @inbounds mul!(y, A, x)
end

(L::LinearMap)(x::AbstractVector) = L*x
Expand Down Expand Up @@ -166,8 +168,8 @@ julia> Y
7.0 7.0
```
"""
function mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector)
check_dim_mul(y, A, x)
@inline function mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector)
@boundscheck check_dim_mul(y, A, x)
return _unsafe_mul!(y, A, x)
end
# the following is of interest in, e.g., subspace-iteration methods
Expand All @@ -194,7 +196,7 @@ julia> mul!(Y, A, b)
```
"""
function mul!(y::AbstractVecOrMat, A::LinearMap, s::Number)
size(y) == size(A) ||
size(y) == size(A) ||
throw(
DimensionMismatch("y has size $(size(y)), A has size $(size(A))."))
return _unsafe_mul!(y, A, s)
Expand Down Expand Up @@ -257,7 +259,7 @@ julia> mul!(Y, A, b, 2, 1)
```
"""
function mul!(y::AbstractMatrix, A::LinearMap, s::Number, α::Number, β::Number)
size(y) == size(A) ||
size(y) == size(A) ||
throw(
DimensionMismatch("y has size $(size(y)), A has size $(size(A))."))
return _unsafe_mul!(y, A, s, α, β)
Expand Down
1 change: 1 addition & 0 deletions src/blockmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ end
# provide one global intermediate storage vector if necessary
__blockmul!(::FiveArg, y, A, x::AbstractVecOrMat, α, β) = ___blockmul!(y, A, x, α, β, nothing)
__blockmul!(::ThreeArg, y, A, x::AbstractVecOrMat, α, β) = ___blockmul!(y, A, x, α, β, similar(y))
__blockmul!(::TwoArg, y, A, x::AbstractVecOrMat, α, β) = ___blockmul!(y, A, x, α, β, nothing)
function ___blockmul!(y, A, x, α, β, ::Nothing)
maps, rows, yinds, xinds = A.maps, A.rows, A.rowranges, A.colranges
mapind = 0
Expand Down
104 changes: 55 additions & 49 deletions src/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,18 @@ Base.:(==)(A::CompositeMap, B::CompositeMap) =
(eltype(A) == eltype(B) && all(A.maps .== B.maps))

# multiplication with vectors/matrices
_unsafe_mul!(y, A::CompositeMap, x::AbstractVector) = _compositemul!(y, A, x)
function Base.:(*)(A::CompositeMap, x::AbstractVector)
MulStyle(A) === TwoArg() ?
foldr(*, reverse(A.maps), init=x) :
invoke(*, Tuple{LinearMap, AbstractVector}, A, x)
end

function _unsafe_mul!(y, A::CompositeMap, x::AbstractVector)
MulStyle(A) === TwoArg() ?
copyto!(y, foldr(*, reverse(A.maps), init=x)) :
_compositemul!(y, A, x)
return y
end
_unsafe_mul!(y, A::CompositeMap, x::AbstractMatrix) = _compositemul!(y, A, x)

function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, x,
Expand All @@ -174,10 +185,50 @@ function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, x,
return _unsafe_mul!(y, A.maps[1], x)
end
function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
source = nothing,
dest = nothing)
if isnothing(source)
z = convert(AbstractArray, A.maps[1] * x)
_unsafe_mul!(y, A.maps[2], z)
return y
else
_unsafe_mul!(source, A.maps[1], x)
_unsafe_mul!(y, A.maps[2], source)
return y
end
end
_compositemul!(y, A::CompositeMap{<:Any,<:LinearMapTuple}, x, s = nothing, d = nothing) =
_compositemulN!(y, A, x, s, d)
function _compositemul!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x,
source = nothing,
dest = nothing)
_unsafe_mul!(source, A.maps[1], x)
_unsafe_mul!(y, A.maps[2], source)
N = length(A.maps)
if N == 1
return _unsafe_mul!(y, A.maps[1], x)
elseif N == 2
return _unsafe_mul!(y, A.maps[2] * A.maps[1], x)
else
return _compositemulN!(y, A, x, source, dest)
end
end

function _compositemulN!(y, A::CompositeMap, x,
src = nothing,
dst = nothing)
N = length(A.maps) # ≥ 3
source = isnothing(src) ?
convert(AbstractArray, A.maps[1] * x) :
_unsafe_mul!(src, A.maps[1], x)
dest = isnothing(dst) ?
convert(AbstractArray, A.maps[2] * source) :
_unsafe_mul!(dst, A.maps[2], source)
dest, source = source, dest # alternate dest and source
for n in 3:N-1
dest = _resize(dest, (size(A.maps[n], 1), size(x)[2:end]...))
_unsafe_mul!(dest, A.maps[n], source)
dest, source = source, dest # alternate dest and source
end
_unsafe_mul!(y, A.maps[N], source)
return y
end

Expand All @@ -197,48 +248,3 @@ function _resize(dest::AbstractMatrix, sz::Tuple{<:Integer,<:Integer})
size(dest) == sz && return dest
similar(dest, sz)
end

function _compositemul!(y, A::CompositeMap{<:Any,<:LinearMapTuple}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
dest = similar(y, (size(A.maps[2],1), size(x)[2:end]...)))
N = length(A.maps)
_unsafe_mul!(source, A.maps[1], x)
for n in 2:N-1
dest = _resize(dest, (size(A.maps[n],1), size(x)[2:end]...))
_unsafe_mul!(dest, A.maps[n], source)
dest, source = source, dest # alternate dest and source
end
_unsafe_mul!(y, A.maps[N], source)
return y
end

function _compositemul!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x)
N = length(A.maps)
if N == 1
return _unsafe_mul!(y, A.maps[1], x)
elseif N == 2
return _compositemul2!(y, A, x)
else
return _compositemulN!(y, A, x)
end
end

function _compositemul2!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)))
_unsafe_mul!(source, A.maps[1], x)
_unsafe_mul!(y, A.maps[2], source)
return y
end
function _compositemulN!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x,
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
dest = similar(y, (size(A.maps[2],1), size(x)[2:end]...)))
N = length(A.maps)
_unsafe_mul!(source, A.maps[1], x)
for n in 2:N-1
dest = _resize(dest, (size(A.maps[n],1), size(x)[2:end]...))
_unsafe_mul!(dest, A.maps[n], source)
dest, source = source, dest # alternate dest and source
end
_unsafe_mul!(y, A.maps[N], source)
return y
end
5 changes: 0 additions & 5 deletions src/fillmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ Base.:(==)(A::FillMap, B::FillMap) = A.λ == B.λ && A.size == B.size
LinearAlgebra.adjoint(A::FillMap) = FillMap(adjoint(A.λ), reverse(A.size))
LinearAlgebra.transpose(A::FillMap) = FillMap(transpose(A.λ), reverse(A.size))

function Base.:(*)(A::FillMap, x::AbstractVector)
T = typeof(oneunit(eltype(A)) * (zero(eltype(x)) + zero(eltype(x))))
return fill(iszero(A.λ) ? zero(T) : A.λ*sum(x), A.size[1])
end

function _unsafe_mul!(y, A::FillMap, x::AbstractVector)
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
end
Expand Down
Loading