Skip to content

Commit 65eb422

Browse files
authored
Overload LinearAlgebra.tr (#210)
1 parent b6b1044 commit 65eb422

File tree

8 files changed

+117
-4
lines changed

8 files changed

+117
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.10.2"
3+
version = "3.11.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/history.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Version history
22

3+
## What's new in v3.11
4+
5+
* The `tr` function from `LinearAlgebra.jl` is now overloaded both for generic `LinearMap`
6+
types and specialized for most provided `LinearMap` types. In the generic case, this is
7+
computationally as expensive as computing the whole matrix representation, though the
8+
latter is, of course, not stored.
9+
310
## What's new in v3.10
411

512
* A new `MulStyle` trait called `TwoArg` has been added. It should be used for `LinearMap`s

docs/src/types.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ as in the usual matrix case: `transpose(A) * x` and `mul!(y, A', x)`, for instan
179179
a linear map for which you only have a function definition (e.g. to be able
180180
to use its `transpose` or `adjoint`).
181181

182+
!!! note
183+
In Julia versions v1.9 and higher, conversion to sparse matrices requires loading
184+
`SparseArrays.jl` by the user in advance.
185+
182186
### Slicing methods
183187

184188
Complete slicing, i.e., `A[:,j]`, `A[:,J]`, `A[i,:]`, `A[I,:]` and `A[:,:]` for `i`, `j`
@@ -188,3 +192,12 @@ slicing) to standard unit vectors of appropriate length. By complete slicing we
188192
two-dimensional Cartesian indexing where at least one of the "indices" is a colon. This is
189193
facilitated by overloads of `Base.getindex`. Partial slicing à la `A[I,J]` and scalar or
190194
linear indexing are _not_ supported.
195+
196+
### Sum, product, mean and trace
197+
198+
Natural function overloads for `Base.sum`, `Base.prod`, `Statistics.mean` and `LinearAlgebra.tr`
199+
exist.
200+
201+
!!! note
202+
In Julia versions v1.9 and higher, creating the mean linear operator requires loading
203+
`Statistics.jl` by the user in advance.

src/LinearMaps.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export ⊗, squarekron, kronsum, ⊕, sumkronsum, khatrirao, facesplitting
55

66
using LinearAlgebra
77
using LinearAlgebra: AbstractQ
8-
import LinearAlgebra: mul!
8+
import LinearAlgebra: mul!, tr
99

1010
using Base: require_one_based_indexing
1111

@@ -348,6 +348,7 @@ include("conversion.jl") # conversion of linear maps to matrices
348348
include("show.jl") # show methods for LinearMap objects
349349
include("getindex.jl") # getindex functionality
350350
include("inversemap.jl")
351+
include("trace.jl")
351352

352353
"""
353354
LinearMap(A::LinearMap; kwargs...)::WrappedMap

src/kronecker.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ compared to [`kron`](@ref), but benchmarking intended use cases is highly recomm
108108
function squarekron(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)
109109
maps = (A, B, C, Ds...)
110110
T = promote_type(map(eltype, maps)...)
111-
all(_issquare, maps) || throw(ArgumentError("operators need to be square in Kronecker sums"))
111+
all(_issquare, maps) || throw(ArgumentError("operators need to be square in squarekron"))
112112
ns = map(a -> size(a, 1), maps)
113113
firstmap = first(maps) UniformScalingMap(true, prod(ns[2:end]))
114114
lastmap = UniformScalingMap(true, prod(ns[1:end-1])) last(maps)
@@ -376,7 +376,7 @@ true
376376
[^1]: Fernandes, P. and Plateau, B. and Stewart, W. J. ["Efficient Descriptor-Vector Multiplications in Stochastic Automata Networks"](https://doi.org/10.1145/278298.278303), _Journal of the ACM_, 45(3), 381–414, 1998.
377377
"""
378378
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix)
379-
LinearAlgebra.checksquare(A, B)
379+
(_issquare(A) && _issquare(B)) || throw(ArgumentError("operators need to be square in Kronecker sums"))
380380
A UniformScalingMap(true, size(B,1)) + UniformScalingMap(true, size(A,1)) B
381381
end
382382
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)

src/trace.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
function tr(A::LinearMap)
2+
_issquare(A) || throw(ArgumentError("operator needs to be square in tr"))
3+
_tr(A)
4+
end
5+
6+
function _tr(A::LinearMap{T}) where {T}
7+
S = typeof(oneunit(eltype(A)) + oneunit(eltype(A)))
8+
ax1, ax2 = axes(A)
9+
xi = zeros(eltype(A), ax2)
10+
y = similar(xi, T, ax1)
11+
o = one(T)
12+
z = zero(T)
13+
s = zero(S)
14+
@inbounds for (i, j) in zip(ax1, ax2)
15+
xi[j] = o
16+
mul!(y, A, xi)
17+
xi[j] = z
18+
s += y[i]
19+
end
20+
return s
21+
end
22+
function _tr(A::OOPFunctionMap{T}) where {T}
23+
S = typeof(oneunit(eltype(A)) + oneunit(eltype(A)))
24+
ax1, ax2 = axes(A)
25+
xi = zeros(eltype(A), ax2)
26+
o = one(T)
27+
z = zero(T)
28+
s = zero(S)
29+
@inbounds for (i, j) in zip(ax1, ax2)
30+
xi[j] = o
31+
s += (A * xi)[i]
32+
xi[j] = z
33+
end
34+
return s
35+
end
36+
# specialiations
37+
_tr(A::AbstractVecOrMat) = tr(A)
38+
_tr(A::WrappedMap) = _tr(A.lmap)
39+
_tr(A::TransposeMap) = _tr(A.lmap)
40+
_tr(A::AdjointMap) = conj(_tr(A.lmap))
41+
_tr(A::UniformScalingMap) = A.M * A.λ
42+
_tr(A::ScaledMap) = A.λ * _tr(A.lmap)
43+
function _tr(L::KroneckerMap)
44+
if all(_issquare, L.maps)
45+
return prod(_tr, L.maps)
46+
else
47+
return invoke(_tr, Tuple{LinearMap}, L)
48+
end
49+
end
50+
function _tr(L::OuterProductMap{<:RealOrComplex})
51+
a, bt = L.maps
52+
return bt.lmap*a.lmap
53+
end
54+
function _tr(L::OuterProductMap)
55+
a, bt = L.maps
56+
mapreduce(*, +, a.lmap, bt.lmap)
57+
end
58+
function _tr(L::KroneckerSumMap)
59+
A, B = L.maps # A and B are square by construction
60+
return _tr(A) * size(B, 1) + _tr(B) * size(A, 1)
61+
end
62+
_tr(A::FillMap) = A.size[1] * A.λ

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,5 @@ include("inversemap.jl")
4343
include("rrules.jl")
4444

4545
include("khatrirao.jl")
46+
47+
include("trace.jl")

test/trace.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using LinearMaps, LinearAlgebra, Test
2+
3+
@testset "trace" begin
4+
for A in (randn(5, 5), randn(ComplexF64, 5, 5))
5+
@test tr(LinearMap(A)) == tr(A)
6+
@test tr(transpose(LinearMap(A))) == tr(A)
7+
@test tr(adjoint(LinearMap(A))) == tr(A')
8+
end
9+
@test tr(LinearMap(3I, 10)) == 30
10+
@test tr(LinearMap{Int}(cumsum, 10)) == 10
11+
@test tr(LinearMap{Int}(cumsum, reversecumsumreverse, 10)') == 10
12+
@test tr(LinearMap{Complex{Int}}(cumsum, reversecumsumreverse, 10)') == 10
13+
@test tr(LinearMap{Int}(cumsum!, 10)) == 10
14+
@test tr(2LinearMap{Int}(cumsum!, 10)) == 20
15+
A = randn(3, 5); B = copy(transpose(A))
16+
@test tr(A B) == tr(kron(A, B))
17+
@test tr(A B A B) tr(kron(A, B, A, B))
18+
A = randn(5, 5); B = copy(transpose(A))
19+
@test tr(A B) tr(kron(A, B))
20+
@test tr(A B A) tr(kron(A, B, A))
21+
@test tr(A B A B) tr(kron(A, B, A, B))
22+
v = A[:,1]
23+
@test tr(v v') norm(v)^2
24+
v = [randn(2,2) for _ in 1:3]
25+
@test tr(v v') mapreduce(*, +, v, v')
26+
@test tr(LinearMap{Int}(cumsum!, 10) LinearMap{Int}(cumsum!, 10)) == 200
27+
@test tr(FillMap(true, 5, 5)) == 5
28+
end

0 commit comments

Comments
 (0)