Skip to content

Commit a3dca0a

Browse files
committed
remove mapreduce for generic f, add Statistics.mean
1 parent 373f3ef commit a3dca0a

File tree

5 files changed

+28
-22
lines changed

5 files changed

+28
-22
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "3.6.0"
55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
77
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
8+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[compat]
1011
julia = "1.6"

src/LinearMaps.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using LinearAlgebra
88
import LinearAlgebra: mul!
99
using SparseArrays
1010

11+
import Statistics: mean
12+
1113
using Base: require_one_based_indexing
1214

1315
abstract type LinearMap{T} end
@@ -22,6 +24,15 @@ const LinearMapTupleOrVector = Union{LinearMapTuple,LinearMapVector}
2224

2325
Base.eltype(::LinearMap{T}) where {T} = T
2426

27+
# conversion to LinearMap
28+
Base.convert(::Type{LinearMap}, A::LinearMap) = A
29+
Base.convert(::Type{LinearMap}, A::AbstractVecOrMat) = LinearMap(A)
30+
31+
convert_to_lmaps() = ()
32+
convert_to_lmaps(A) = (convert(LinearMap, A),)
33+
@inline convert_to_lmaps(A, B, Cs...) =
34+
(convert(LinearMap, A), convert(LinearMap, B), convert_to_lmaps(Cs...)...)
35+
2536
abstract type MulStyle end
2637

2738
struct FiveArg <: MulStyle end
@@ -65,14 +76,6 @@ function check_dim_mul(C, A, B)
6576
return nothing
6677
end
6778

68-
# conversion of AbstractVecOrMat to LinearMap
69-
convert_to_lmaps_(A::AbstractVecOrMat) = LinearMap(A)
70-
convert_to_lmaps_(A::LinearMap) = A
71-
convert_to_lmaps() = ()
72-
convert_to_lmaps(A) = (convert_to_lmaps_(A),)
73-
@inline convert_to_lmaps(A, B, Cs...) =
74-
(convert_to_lmaps_(A), convert_to_lmaps_(B), convert_to_lmaps(Cs...)...)
75-
7679
_front(As::Tuple) = Base.front(As)
7780
_front(As::AbstractVector) = @inbounds @views As[1:end-1]
7881
_tail(As::Tuple) = Base.tail(As)

src/kronecker.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ Construct a lazy representation of the `k`-th Kronecker power
9797
(A, B, Cs...) = kron(convert_to_lmaps(A, B, Cs...)...)
9898

9999
Base.:(^)(A::MapOrMatrix, ::KronPower{p}) where {p} =
100-
kron(ntuple(n -> convert_to_lmaps_(A), Val(p))...)
100+
kron(ntuple(n -> convert(LinearMap, A), Val(p))...)
101101

102102
Base.size(A::KroneckerMap) = map(*, size.(A.maps)...)
103103

@@ -287,7 +287,7 @@ where `A` can be a square `AbstractMatrix` or a `LinearMap`.
287287
(a, b, c...) = kronsum(a, b, c...)
288288

289289
Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} =
290-
kronsum(ntuple(n->convert_to_lmaps_(A), Val(p))...)
290+
kronsum(ntuple(n -> convert(LinearMap, A), Val(p))...)
291291

292292
Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i))
293293
Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2))

src/linearcombination.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@ Base.mapreduce(::typeof(identity), ::typeof(Base.add_sum), maps::LinearMapTupleO
2020
# this method is required for type stability in the mixed-map-equal-eltype case
2121
Base.mapreduce(::typeof(identity), ::typeof(Base.add_sum), maps::AbstractVector{<:LinearMap{T}}) where {T} =
2222
LinearCombination{T}(maps)
23-
# the following two methods are needed to make e.g. `mean` work,
24-
# for which `f` is some sort of promotion function
25-
Base.mapreduce(f::F, ::typeof(Base.add_sum), maps::LinearMapTupleOrVector) where {F} =
26-
LinearCombination{promote_type(map(eltype, maps)...)}(f.(maps))
27-
Base.mapreduce(f::F, ::typeof(Base.add_sum), maps::AbstractVector{<:LinearMap{T}}) where {F,T} =
28-
LinearCombination{T}(f.(maps))
23+
24+
mean(f::F, maps::LinearMapTupleOrVector) where {F} = sum(f, maps) / length(maps)
25+
mean(maps::LinearMapTupleOrVector) = sum(maps) / length(maps)
2926

3027
MulStyle(A::LinearCombination) = MulStyle(A.maps...)
3128

test/linearcombination.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@ using LinearMaps: FiveArg, LinearMapTuple, LinearMapVector
1616
@test (@inferred sum(Lv.maps::LinearMapVector)) == Lv
1717
@test isa((@inferred mean(Lv.maps)),
1818
LinearMaps.ScaledMap{ComplexF64,Float64,<:LinearMaps.LinearCombination{ComplexF64,<:LinearMapVector}})
19-
@test (@inferred mean(x -> x*x, Lv.maps)) == (@inferred sum(x -> x*x, Lv.maps)/n)
19+
@test (@inferred mean(L.maps)) == (@inferred mean(Lv.maps)) == (@inferred sum(Lv.maps))/n
20+
@test (@inferred mean(x -> x*x, L.maps)) == (@inferred sum(x -> x*x, L.maps))/n
21+
@test mean(x -> x*x, Lv.maps) == (sum(x -> x*x, Lv.maps))/n
2022
@test L == Lv
21-
@test isa((@inferred sum([CS!, LinearMap(randn(eltype(CS!), size(CS!)))])),
22-
LinearMaps.LinearCombination{<:ComplexF64,<:LinearMapVector})
23-
A = randn(eltype(CS!), size(CS!))
24-
@test (@inferred mean([CS!, LinearMap(A)])) == (@inferred sum([CS!, LinearMap(A)])/2)
25-
@test isa(sum([CS!, LinearMap(real(A))]),
23+
A = LinearMap(randn(eltype(CS!), size(CS!)))
24+
Ar = LinearMap(real(A.lmap))
25+
@test isa((@inferred sum([CS!, A])),
2626
LinearMaps.LinearCombination{<:ComplexF64,<:LinearMapVector})
27+
@test (@inferred mean([CS!, A])) == (@inferred sum([CS!, A]))/2
28+
@test (@inferred mean([CS!, A])) == (@inferred mean(identity, [CS!, A])) == (@inferred sum([CS!, A]))/2
29+
@test isa(sum([CS!, Ar]), LinearMaps.LinearCombination{<:ComplexF64,<:LinearMapVector})
30+
@test sum([CS!, Ar])/2 == mean([CS!, Ar])
31+
@test sum([CS!, Ar]) == sum(identity, [CS!, Ar])
2732
for sum1 in (CS!, L, Lv), sum2 in (CS!, L, Lv)
2833
m1 = sum1 == CS! ? 1 : 10
2934
m2 = sum2 == CS! ? 1 : 10

0 commit comments

Comments
 (0)