Skip to content

Commit b6b1044

Browse files
authored
Handle dependencies via Pkg extensions (#208)
1 parent f522a8f commit b6b1044

File tree

8 files changed

+140
-108
lines changed

8 files changed

+140
-108
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.10.1"
3+
version = "3.10.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -10,9 +10,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1010

1111
[weakdeps]
1212
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
13+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
14+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1315

1416
[extensions]
1517
LinearMapsChainRulesCoreExt = "ChainRulesCore"
18+
LinearMapsSparseArraysExt = "SparseArrays"
19+
LinearMapsStatisticsExt = "Statistics"
1620

1721
[compat]
1822
ChainRulesCore = "1"

ext/LinearMapsSparseArraysExt.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
module LinearMapsSparseArraysExt
2+
3+
import SparseArrays: sparse, blockdiag, SparseMatrixCSC
4+
using SparseArrays: AbstractSparseMatrix
5+
6+
using LinearAlgebra, LinearMaps
7+
import LinearMaps: _issymmetric, _ishermitian
8+
using LinearMaps: WrappedMap, CompositeMap, LinearCombination, ScaledMap, UniformScalingMap,
9+
AdjointMap, TransposeMap, BlockMap, BlockDiagonalMap, KroneckerMap, KroneckerSumMap,
10+
VecOrMatMap, AbstractVecOrMatOrQ, MapOrVecOrMat
11+
using LinearMaps: convert_to_lmaps, _tail, _unsafe_mul!
12+
13+
_issymmetric(A::AbstractSparseMatrix) = issymmetric(A)
14+
_ishermitian(A::AbstractSparseMatrix) = ishermitian(A)
15+
16+
# blockdiagonal concatenation via extension of blockdiag
17+
18+
"""
19+
blockdiag(As::Union{LinearMap,AbstractVecOrMatOrQ}...)::BlockDiagonalMap
20+
21+
Construct a (lazy) representation of the diagonal concatenation of the arguments.
22+
To avoid fallback to the generic `blockdiag`, there must be a `LinearMap`
23+
object among the first 8 arguments.
24+
"""
25+
blockdiag
26+
27+
for k in 1:8 # is 8 sufficient?
28+
Is = ntuple(n->:($(Symbol(:A, n))::AbstractVecOrMatOrQ), Val(k-1))
29+
# yields (:A1, :A2, :A3, ..., :A(k-1))
30+
L = :($(Symbol(:A, k))::LinearMap)
31+
# yields :Ak
32+
mapargs = ntuple(n ->:($(Symbol(:A, n))), Val(k-1))
33+
# yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1)))
34+
35+
@eval function blockdiag($(Is...), $L, As::MapOrVecOrMat...)
36+
return BlockDiagonalMap(convert_to_lmaps($(mapargs...))...,
37+
$(Symbol(:A, k)),
38+
convert_to_lmaps(As...)...)
39+
end
40+
end
41+
42+
# conversion to sparse arrays
43+
# sparse: create sparse matrix representation of LinearMap
44+
function sparse(A::LinearMap{T}) where {T}
45+
M, N = size(A)
46+
rowind = Int[]
47+
nzval = T[]
48+
colptr = Vector{Int}(undef, N+1)
49+
v = fill(zero(T), N)
50+
Av = Vector{T}(undef, M)
51+
52+
@inbounds for i in eachindex(v)
53+
v[i] = one(T)
54+
_unsafe_mul!(Av, A, v)
55+
js = findall(!iszero, Av)
56+
colptr[i] = length(nzval) + 1
57+
if length(js) > 0
58+
append!(rowind, js)
59+
append!(nzval, Av[js])
60+
end
61+
v[i] = zero(T)
62+
end
63+
colptr[N+1] = length(nzval) + 1
64+
65+
return SparseMatrixCSC(M, N, colptr, rowind, nzval)
66+
end
67+
Base.convert(::Type{SparseMatrixCSC}, A::LinearMap) = sparse(A)
68+
SparseMatrixCSC(A::LinearMap) = sparse(A)
69+
70+
sparse(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) =
71+
A.λ * sparse(A.lmap.lmap)
72+
sparse(A::WrappedMap) = sparse(A.lmap)
73+
Base.convert(::Type{SparseMatrixCSC}, A::WrappedMap) = convert(SparseMatrixCSC, A.lmap)
74+
for (T, t) in ((:AdjointMap, adjoint), (:TransposeMap, transpose))
75+
@eval sparse(A::$T) = $t(convert(SparseMatrixCSC, A.lmap))
76+
end
77+
function sparse(ΣA::LinearCombination{<:Any, <:Tuple{Vararg{VecOrMatMap}}})
78+
mats = map(A->getfield(A, :lmap), ΣA.maps)
79+
return sum(sparse, mats)
80+
end
81+
function sparse(AB::CompositeMap{<:Any, <:Tuple{VecOrMatMap, VecOrMatMap}})
82+
B, A = AB.maps
83+
return sparse(A.lmap)*sparse(B.lmap)
84+
end
85+
function sparse(λA::CompositeMap{<:Any, <:Tuple{VecOrMatMap, UniformScalingMap}})
86+
A, J = λA.maps
87+
return J.λ*sparse(A.lmap)
88+
end
89+
function sparse(Aλ::CompositeMap{<:Any, <:Tuple{UniformScalingMap, VecOrMatMap}})
90+
J, A =.maps
91+
return sparse(A.lmap)*J.λ
92+
end
93+
function sparse(A::BlockMap)
94+
return hvcat(
95+
A.rows,
96+
convert(SparseMatrixCSC, first(A.maps)),
97+
convert.(AbstractArray, _tail(A.maps))...
98+
)
99+
end
100+
function sparse(A::BlockDiagonalMap)
101+
return blockdiag(convert.(SparseMatrixCSC, A.maps)...)
102+
end
103+
Base.convert(::Type{AbstractMatrix}, A::BlockDiagonalMap) = sparse(A)
104+
function sparse(A::KroneckerMap)
105+
return kron(
106+
convert(SparseMatrixCSC, first(A.maps)),
107+
convert.(AbstractMatrix, _tail(A.maps))...
108+
)
109+
end
110+
function sparse(L::KroneckerSumMap)
111+
A, B = L.maps
112+
IA = sparse(Diagonal(ones(Bool, size(A, 1))))
113+
IB = sparse(Diagonal(ones(Bool, size(B, 1))))
114+
return kron(convert(AbstractMatrix, A), IB) + kron(IA, convert(AbstractMatrix, B))
115+
end
116+
117+
end # module LinearMapsSparseArraysExt

ext/LinearMapsStatisticsExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module LinearMapsStatisticsExt
2+
3+
import Statistics: mean
4+
5+
using LinearMaps
6+
using LinearMaps: LinearMapTupleOrVector, LinearCombination
7+
8+
mean(f::F, maps::LinearMapTupleOrVector) where {F} = sum(f, maps) / length(maps)
9+
mean(maps::LinearMapTupleOrVector) = mean(identity, maps)
10+
mean(A::LinearCombination) = mean(A.maps)
11+
12+
end # module ChainRulesCore

src/LinearMaps.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@ export ⊗, squarekron, kronsum, ⊕, sumkronsum, khatrirao, facesplitting
66
using LinearAlgebra
77
using LinearAlgebra: AbstractQ
88
import LinearAlgebra: mul!
9-
using SparseArrays
10-
11-
import Statistics: mean
129

1310
using Base: require_one_based_indexing
1411

@@ -422,6 +419,8 @@ LinearMap{T}(f, args...; kwargs...) where {T} = FunctionMap{T}(f, args...; kwarg
422419

423420
@static if !isdefined(Base, :get_extension)
424421
include("../ext/LinearMapsChainRulesCoreExt.jl")
422+
include("../ext/LinearMapsSparseArraysExt.jl")
423+
include("../ext/LinearMapsStatisticsExt.jl")
425424
end
426425

427426
end # module

src/blockmap.jl

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -505,34 +505,17 @@ for k in 1:8 # is 8 sufficient?
505505
mapargs = ntuple(n ->:($(Symbol(:A, n))), Val(k-1))
506506
# yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1)))
507507

508-
@eval begin
509-
function SparseArrays.blockdiag($(Is...), $L, As::MapOrVecOrMat...)
508+
@eval function Base.cat($(Is...), $L, As::MapOrVecOrMat...; dims::Dims{2})
509+
if dims == (1,2)
510510
return BlockDiagonalMap(convert_to_lmaps($(mapargs...))...,
511511
$(Symbol(:A, k)),
512512
convert_to_lmaps(As...)...)
513-
end
514-
515-
function Base.cat($(Is...), $L, As::MapOrVecOrMat...; dims::Dims{2})
516-
if dims == (1,2)
517-
return BlockDiagonalMap(convert_to_lmaps($(mapargs...))...,
518-
$(Symbol(:A, k)),
519-
convert_to_lmaps(As...)...)
520-
else
521-
throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)"))
522-
end
513+
else
514+
throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)"))
523515
end
524516
end
525517
end
526518

527-
"""
528-
blockdiag(As::Union{LinearMap,AbstractVecOrMatOrQ}...)::BlockDiagonalMap
529-
530-
Construct a (lazy) representation of the diagonal concatenation of the arguments.
531-
To avoid fallback to the generic `SparseArrays.blockdiag`, there must be a `LinearMap`
532-
object among the first 8 arguments.
533-
"""
534-
SparseArrays.blockdiag
535-
536519
"""
537520
cat(As::Union{LinearMap,AbstractVecOrMatOrQ}...; dims=(1,2))::BlockDiagonalMap
538521

src/conversion.jl

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -19,39 +19,8 @@ Base.convert(::Type{Array}, A::LinearMap) = convert(Matrix, A)
1919
Base.convert(::Type{AbstractMatrix}, A::LinearMap) = AbstractMatrix(A)
2020
Base.convert(::Type{AbstractArray}, A::LinearMap) = convert(AbstractMatrix, A)
2121

22-
# sparse: create sparse matrix representation of LinearMap
23-
function SparseArrays.sparse(A::LinearMap{T}) where {T}
24-
M, N = size(A)
25-
rowind = Int[]
26-
nzval = T[]
27-
colptr = Vector{Int}(undef, N+1)
28-
v = fill(zero(T), N)
29-
Av = Vector{T}(undef, M)
30-
31-
@inbounds for i in eachindex(v)
32-
v[i] = one(T)
33-
_unsafe_mul!(Av, A, v)
34-
js = findall(!iszero, Av)
35-
colptr[i] = length(nzval) + 1
36-
if length(js) > 0
37-
append!(rowind, js)
38-
append!(nzval, Av[js])
39-
end
40-
v[i] = zero(T)
41-
end
42-
colptr[N+1] = length(nzval) + 1
43-
44-
return SparseMatrixCSC(M, N, colptr, rowind, nzval)
45-
end
46-
Base.convert(::Type{SparseMatrixCSC}, A::LinearMap) = sparse(A)
47-
SparseArrays.SparseMatrixCSC(A::LinearMap) = sparse(A)
48-
4922
# special cases
5023

51-
# ScaledMap
52-
SparseArrays.sparse(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) =
53-
A.λ * sparse(A.lmap.lmap)
54-
5524
# UniformScalingMap
5625
Base.convert(::Type{AbstractMatrix}, J::UniformScalingMap) = Diagonal(fill(J.λ, J.M))
5726

@@ -61,19 +30,10 @@ Base.convert(::Type{T}, A::WrappedMap) where {T<:Matrix} = convert(T, A.lmap)
6130
Base.Matrix{T}(A::VectorMap) where {T} = copyto!(Matrix{eltype(T)}(undef, size(A)), A.lmap)
6231
Base.convert(::Type{T}, A::VectorMap) where {T<:Matrix} = T(A)
6332
Base.convert(::Type{AbstractMatrix}, A::WrappedMap) = convert(AbstractMatrix, A.lmap)
64-
SparseArrays.sparse(A::WrappedMap) = sparse(A.lmap)
65-
Base.convert(::Type{SparseMatrixCSC}, A::WrappedMap) = convert(SparseMatrixCSC, A.lmap)
6633

6734
# TransposeMap & AdjointMap
6835
for (T, t) in ((AdjointMap, adjoint), (TransposeMap, transpose))
6936
@eval Base.convert(::Type{AbstractMatrix}, A::$T) = $t(convert(AbstractMatrix, A.lmap))
70-
@eval SparseArrays.sparse(A::$T) = $t(convert(SparseMatrixCSC, A.lmap))
71-
end
72-
73-
# LinearCombination
74-
function SparseArrays.sparse(ΣA::LinearCombination{<:Any, <:Tuple{Vararg{VecOrMatMap}}})
75-
mats = map(A->getfield(A, :lmap), ΣA.maps)
76-
return sum(sparse, mats)
7737
end
7838

7939
# CompositeMap
@@ -99,50 +59,19 @@ function Base.Matrix{T}(AB::CompositeMap{<:Any, <:Tuple{VecOrMatMap, VecOrMatMap
9959
B, A = AB.maps
10060
return mul!(Matrix{T}(undef, size(AB)), A.lmap, B.lmap)
10161
end
102-
function SparseArrays.sparse(AB::CompositeMap{<:Any, <:Tuple{VecOrMatMap, VecOrMatMap}})
103-
B, A = AB.maps
104-
return sparse(A.lmap)*sparse(B.lmap)
105-
end
10662
function Base.Matrix{T}(λA::CompositeMap{<:Any, <:Tuple{VecOrMatMap, UniformScalingMap}}) where {T}
10763
A, J = λA.maps
10864
return mul!(Matrix{T}(undef, size(λA)), J.λ, A.lmap)
10965
end
110-
function SparseArrays.sparse(λA::CompositeMap{<:Any, <:Tuple{VecOrMatMap, UniformScalingMap}})
111-
A, J = λA.maps
112-
return J.λ*sparse(A.lmap)
113-
end
11466
function Base.Matrix{T}(Aλ::CompositeMap{<:Any, <:Tuple{UniformScalingMap, VecOrMatMap}}) where {T}
11567
J, A =.maps
11668
return mul!(Matrix{T}(undef, size(Aλ)), A.lmap, J.λ)
11769
end
118-
function SparseArrays.sparse(Aλ::CompositeMap{<:Any, <:Tuple{UniformScalingMap, VecOrMatMap}})
119-
J, A =.maps
120-
return sparse(A.lmap)*J.λ
121-
end
122-
123-
# BlockMap & BlockDiagonalMap
124-
function SparseArrays.sparse(A::BlockMap)
125-
return hvcat(
126-
A.rows,
127-
convert(SparseMatrixCSC, first(A.maps)),
128-
convert.(AbstractArray, _tail(A.maps))...
129-
)
130-
end
131-
Base.convert(::Type{AbstractMatrix}, A::BlockDiagonalMap) = sparse(A)
132-
function SparseArrays.sparse(A::BlockDiagonalMap)
133-
return blockdiag(convert.(SparseMatrixCSC, A.maps)...)
134-
end
13570

13671
# KroneckerMap & KroneckerSumMap
13772
Base.Matrix{T}(A::KroneckerMap) where {T} = kron(convert.(Matrix{T}, A.maps)...)
13873
Base.convert(::Type{AbstractMatrix}, A::KroneckerMap) =
13974
kron(convert.(AbstractMatrix, A.maps)...)
140-
function SparseArrays.sparse(A::KroneckerMap)
141-
return kron(
142-
convert(SparseMatrixCSC, first(A.maps)),
143-
convert.(AbstractMatrix, _tail(A.maps))...
144-
)
145-
end
14675

14776
function Base.Matrix{T}(L::KroneckerSumMap) where {T}
14877
A, B = L.maps
@@ -156,12 +85,6 @@ function Base.convert(::Type{AbstractMatrix}, L::KroneckerSumMap)
15685
IB = Diagonal(ones(Bool, size(B, 1)))
15786
return kron(convert(AbstractMatrix, A), IB) + kron(IA, convert(AbstractMatrix, B))
15887
end
159-
function SparseArrays.sparse(L::KroneckerSumMap)
160-
A, B = L.maps
161-
IA = sparse(Diagonal(ones(Bool, size(A, 1))))
162-
IB = sparse(Diagonal(ones(Bool, size(B, 1))))
163-
return kron(convert(AbstractMatrix, A), IB) + kron(IA, convert(AbstractMatrix, B))
164-
end
16588

16689
# FillMap
16790
Base.Matrix{T}(A::FillMap) where {T} = fill(T(A.λ), size(A))

src/linearcombination.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ Base.mapreduce(::typeof(identity), ::typeof(Base.add_sum), maps::LinearMapTupleO
2121
Base.mapreduce(::typeof(identity), ::typeof(Base.add_sum), maps::AbstractVector{<:LinearMap{T}}) where {T} =
2222
LinearCombination{T}(maps)
2323

24-
mean(f::F, maps::LinearMapTupleOrVector) where {F} = sum(f, maps) / length(maps)
25-
mean(maps::LinearMapTupleOrVector) = mean(identity, maps)
26-
mean(A::LinearCombination) = mean(A.maps)
27-
2824
MulStyle(A::LinearCombination) = MulStyle(A.maps...)
2925

3026
# basic methods

src/wrappedmap.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,12 @@ WrappedMap(lmap::MapOrVecOrMat{T}; kwargs...) where {T} = WrappedMap{T}(lmap; kw
2424
# cheap property checks (usually by type)
2525
_issymmetric(A::AbstractMatrix) = false
2626
_issymmetric(A::AbstractQ) = false
27-
_issymmetric(A::AbstractSparseMatrix) = issymmetric(A)
2827
_issymmetric(A::LinearMap) = issymmetric(A)
2928
_issymmetric(A::LinearAlgebra.RealHermSymComplexSym) = issymmetric(A)
3029
_issymmetric(A::Union{Bidiagonal,Diagonal,SymTridiagonal,Tridiagonal}) = issymmetric(A)
3130

3231
_ishermitian(A::AbstractMatrix) = false
3332
_ishermitian(A::AbstractQ) = false
34-
_ishermitian(A::AbstractSparseMatrix) = ishermitian(A)
3533
_ishermitian(A::LinearMap) = ishermitian(A)
3634
_ishermitian(A::LinearAlgebra.RealHermSymComplexHerm) = ishermitian(A)
3735
_ishermitian(A::Union{Bidiagonal,Diagonal,SymTridiagonal,Tridiagonal}) = ishermitian(A)

0 commit comments

Comments
 (0)