Skip to content

Commit 8b41f1a

Browse files
committed
Square Kronecker products and sums
1 parent 710a0d3 commit 8b41f1a

File tree

7 files changed

+242
-111
lines changed

7 files changed

+242
-111
lines changed

docs/src/types.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ kronsum
7272
LinearMaps.:⊕
7373
```
7474

75+
There exist alternative constructors of Kronecker products and sums for square factors and
76+
summands, respectively. These are designed for cases of 3 or more arguments, and
77+
benchmarking intended use cases for comparison with `KroneckerMap` and `KroneckerSumMap`
78+
is recommended.
79+
80+
```@docs
81+
squarekron
82+
sumkronsum
83+
```
84+
7585
### `BlockMap` and `BlockDiagonalMap`
7686

7787
Types for representing block (diagonal) maps lazily.
@@ -116,6 +126,7 @@ Base.:*(::AbstractMatrix,::LinearMap)
116126
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector)
117127
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector,::Number,::Number)
118128
LinearAlgebra.mul!(::AbstractMatrix,::AbstractMatrix,::LinearMap)
129+
LinearAlgebra.mul!(::AbstractMatrix,::AbstractMatrix,::LinearMap,::Number,::Number)
119130
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::Number)
120131
LinearAlgebra.mul!(::AbstractMatrix,::LinearMap,::Number,::Number,::Number)
121132
*(::LinearAlgebra.AdjointAbsVec,::LinearMap)

src/LinearMaps.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LinearMaps
22

33
export LinearMap
4-
export , kronsum,
4+
export , squarekron, kronsum, , sumkronsum
55
export FillMap
66
export InverseMap
77

@@ -17,6 +17,7 @@ abstract type LinearMap{T} end
1717

1818
const MapOrVecOrMat{T} = Union{LinearMap{T}, AbstractVecOrMat{T}}
1919
const MapOrMatrix{T} = Union{LinearMap{T}, AbstractMatrix{T}}
20+
const TransposeAbsVecOrMat{T} = Transpose{T,<:AbstractVecOrMat}
2021
const RealOrComplex = Union{Real, Complex}
2122

2223
const LinearMapTuple = Tuple{Vararg{LinearMap}}
@@ -77,10 +78,12 @@ function check_dim_mul(C, A, B)
7778
return nothing
7879
end
7980

81+
_issquare(A) = size(A, 1) == size(A, 2)
82+
8083
_front(As::Tuple) = Base.front(As)
81-
_front(As::AbstractVector) = @inbounds @views As[1:end-1]
84+
_front(As::AbstractVector) = @inbounds @views As[begin:end-1]
8285
_tail(As::Tuple) = Base.tail(As)
83-
_tail(As::AbstractVector) = @inbounds @views As[2:end]
86+
_tail(As::AbstractVector) = @inbounds @views As[begin+1:end]
8487

8588
_combine(A::LinearMap, B::LinearMap) = tuple(A, B)
8689
_combine(A::LinearMap, Bs::LinearMapTuple) = tuple(A, Bs...)
@@ -330,9 +333,9 @@ function _generic_map_mul!(Y, A, s::Number, α, β)
330333
return Y
331334
end
332335

333-
include("left.jl") # left multiplication by a transpose or adjoint vector
334336
include("transpose.jl") # transposing linear maps
335337
include("wrappedmap.jl") # wrap a matrix of linear map in a new type, thereby allowing to alter its properties
338+
include("left.jl") # left multiplication by a matrix/transpose or adjoint vector
336339
include("uniformscalingmap.jl") # the uniform scaling map, to be able to make linear combinations of LinearMap objects and multiples of I
337340
include("linearcombination.jl") # defining linear combinations of linear maps
338341
include("scaledmap.jl") # multiply by a (real or complex) scalar

src/fillmap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Base.size(A::FillMap) = A.size
2121
MulStyle(A::FillMap) = FiveArg()
2222
LinearAlgebra.issymmetric(A::FillMap) = A.size[1] == A.size[2]
2323
LinearAlgebra.ishermitian(A::FillMap) = isreal(A.λ) && A.size[1] == A.size[2]
24-
LinearAlgebra.isposdef(A::FillMap) = (size(A, 1) == size(A, 2) == 1 && isposdef(A.λ))
24+
LinearAlgebra.isposdef(A::FillMap) = (LinearAlgebra.checksquare(A) == 1 && isposdef(A.λ))
2525
Base.:(==)(A::FillMap, B::FillMap) = A.λ == B.λ && A.size == B.size
2626

2727
LinearAlgebra.adjoint(A::FillMap) = FillMap(adjoint(A.λ), reverse(A.size))

src/kronecker.jl

Lines changed: 167 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,43 @@ for k in 3:8 # is 8 sufficient?
7979
kron($(mapargs...), $(Symbol(:A, k)), convert_to_lmaps(As...)...)
8080
end
8181

82+
"""
83+
squarekron(A₁::MapOrMatrix, A₂::MapOrMatrix, A₃::MapOrMatrix, Aᵢ::MapOrMatrix...)::CompositeMap
84+
85+
Construct a (lazy) representation of the Kronecker product `⨂ᵢ₌₁ⁿ Aᵢ` of at least 3 _square_
86+
Kronecker factors. In contrast to [`kron`](@ref), this function assumes that all Kronecker
87+
factors are square, and makes use of the following identity:
88+
89+
```
90+
\\bigotimes_{i=1}^n A_i = \\prod_{i=1}^n I_1 \\otimes \\ldots \\otimes I_{i-1} \\otimes A_i \\otimes I_{i+1} \\otimes \\ldots \\otimes I_n
91+
```
92+
93+
where ```I_k``` is an identity matrix of the size of ```A_k```. By associativity, the
94+
Kronecker product of the identity operators may be combined to larger identity operators
95+
```I_{1:i-1}``` and ```I_{i+1:n}```, which yields
96+
97+
```
98+
\\bigotimes_{i=1}^n A_i = \\prod_{i=1}^n I_{1:i-1} \\otimes A_i \\otimes I_{i+1:n}
99+
```
100+
101+
i.e., a `CompositeMap` where each factor is a Kronecker product consisting of three maps:
102+
outer `UniformScalingMap`s and the respective Kronecker factor. This representation is
103+
expected to yield significantly faster multiplication (and reduce memory allocation)
104+
compared to [`kron`](@ref), but benchmarking intended use cases is highly recommended.
105+
"""
106+
function squarekron(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)
107+
maps = (A, B, C, Ds...)
108+
T = promote_type(map(eltype, maps)...)
109+
all(_issquare, maps) || throw(ArgumentError("operators need to be square in Kronecker sums"))
110+
ns = map(a -> size(a, 1), maps)
111+
firstmap = first(maps) UniformScalingMap(true, prod(ns[2:end]))
112+
lastmap = UniformScalingMap(true, prod(ns[1:end-1])) last(maps)
113+
middlemaps = prod(enumerate(maps[2:end-1])) do (i, map)
114+
UniformScalingMap(true, prod(ns[1:i])) map UniformScalingMap(true, prod(ns[i+2:end]))
115+
end
116+
return firstmap * middlemaps * lastmap
117+
end
118+
82119
struct KronPower{p}
83120
function KronPower(p::Integer)
84121
p > 1 || throw(ArgumentError("the Kronecker power is only defined for exponents larger than 1, got $k"))
@@ -114,74 +151,90 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) =
114151
# multiplication helper functions
115152
#################
116153

117-
@inline function _kronmul!(y, B, x, A, T)
118-
ma, na = size(A)
119-
mb, nb = size(B)
120-
X = reshape(x, (nb, na))
121-
Y = reshape(y, (mb, ma))
122-
if B isa UniformScalingMap
123-
_unsafe_mul!(Y, X, transpose(A))
124-
lmul!(B.λ, y)
154+
@inline function _kronmul!(Y, B, X, A)
155+
# minimize intermediate memory allocation
156+
if size(B, 2) * size(A, 1) <= size(B, 1) * size(A, 2)
157+
temp = similar(Y, (size(B, 2), size(A, 1) ))
158+
_unsafe_mul!(temp, X, transpose(A))
159+
_unsafe_mul!(Y, B, temp)
125160
else
126-
temp = similar(Y, (ma, nb))
127-
_unsafe_mul!(temp, A, copy(transpose(X)))
128-
_unsafe_mul!(Y, B, transpose(temp))
161+
temp = similar(Y, (size(B, 1), size(A, 2)))
162+
_unsafe_mul!(temp, B, X)
163+
_unsafe_mul!(Y, temp, transpose(A))
129164
end
130-
return y
165+
return Y
131166
end
132-
@inline function _kronmul!(y, B, x, A::UniformScalingMap, _)
133-
ma, na = size(A)
134-
mb, nb = size(B)
135-
iszero(A.λ) && return fill!(y, zero(eltype(y)))
136-
X = reshape(x, (nb, na))
137-
Y = reshape(y, (mb, ma))
167+
@inline function _kronmul!(Y, B::UniformScalingMap, X, A)
168+
_unsafe_mul!(Y, X, transpose(A))
169+
!isone(B.λ) && lmul!(B.λ, Y)
170+
return Y
171+
end
172+
@inline function _kronmul!(Y, B, X, A::UniformScalingMap)
138173
_unsafe_mul!(Y, B, X)
139-
!isone(A.λ) && rmul!(y, A.λ)
140-
return y
174+
!isone(A.λ) && rmul!(Y, A.λ)
175+
return Y
141176
end
142-
@inline function _kronmul!(y, B, x, A::VecOrMatMap, _)
143-
ma, na = size(A)
144-
mb, nb = size(B)
145-
X = reshape(x, (nb, na))
146-
Y = reshape(y, (mb, ma))
177+
# disambiguation (cannot occur)
178+
@inline function _kronmul!(Y, B::UniformScalingMap, X, A::UniformScalingMap)
179+
mul!(parent(Y), A.λ * B.λ, parent(X))
180+
return Y
181+
end
182+
@inline function _kronmul!(Y, B, X, A::VecOrMatMap)
147183
At = transpose(A.lmap)
148-
if B isa UniformScalingMap
149-
# the following is (perhaps due to the reshape?) faster than
150-
# _unsafe_mul!(Y, B * X, At)
151-
_unsafe_mul!(Y, X, At)
152-
lmul!(B.λ, y)
153-
elseif nb*ma <= mb*na
184+
if size(B, 2) * size(A, 1) <= size(B, 1) * size(A, 2)
154185
_unsafe_mul!(Y, B, X * At)
155186
else
156187
_unsafe_mul!(Y, Matrix(B * X), At)
157188
end
158-
return y
189+
return Y
190+
end
191+
@inline function _kronmul!(Y, B::UniformScalingMap, X, A::VecOrMatMap)
192+
_unsafe_mul!(Y, X, transpose(A.lmap))
193+
!isone(B.λ) && lmul!(B.λ, Y)
194+
return Y
159195
end
196+
160197
const VectorMap{T} = WrappedMap{T,<:AbstractVector}
161198
const AdjOrTransVectorMap{T} = WrappedMap{T,<:LinearAlgebra.AdjOrTransAbsVec}
162-
@inline _kronmul!(y, B::AdjOrTransVectorMap, x, a::VectorMap, _) = mul!(y, a.lmap, B.lmap * x)
163199

164200
#################
165201
# multiplication with vectors
166202
#################
167203

168204
const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}}
169-
205+
const OuterProductMap{T} = KroneckerMap{T, <:Tuple{VectorMap, AdjOrTransVectorMap}}
206+
function _unsafe_mul!(y::AbstractVecOrMat, L::OuterProductMap, x::AbstractVector)
207+
a, bt = L.maps
208+
mul!(y, a.lmap, bt.lmap * x)
209+
end
170210
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap2, x::AbstractVector)
171211
require_one_based_indexing(y)
172212
A, B = L.maps
173-
_kronmul!(y, B, x, A, eltype(L))
213+
ma, na = size(A)
214+
mb, nb = size(B)
215+
X = reshape(x, (nb, na))
216+
Y = reshape(y, (mb, ma))
217+
_kronmul!(Y, B, X, A)
174218
return y
175219
end
176220
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap, x::AbstractVector)
177221
require_one_based_indexing(y)
178222
maps = L.maps
179223
if length(maps) == 2 # reachable only for L.maps::Vector
180-
@inbounds _kronmul!(y, maps[2], x, maps[1], eltype(L))
224+
A, B = maps
225+
ma, na = size(A)
226+
mb, nb = size(B)
227+
X = reshape(x, (nb, na))
228+
Y = reshape(y, (mb, ma))
229+
_kronmul!(Y, B, X, A)
181230
else
182231
A = first(maps)
183232
B = KroneckerMap{eltype(L)}(_tail(maps))
184-
_kronmul!(y, B, x, A, eltype(L))
233+
ma, na = size(A)
234+
mb, nb = size(B)
235+
X = reshape(x, (nb, na))
236+
Y = reshape(y, (mb, ma))
237+
_kronmul!(Y, B, X, A)
185238
end
186239
return y
187240
end
@@ -225,7 +278,7 @@ struct KroneckerSumMap{T, As<:Tuple{LinearMap, LinearMap}} <: LinearMap{T}
225278
maps::As
226279
function KroneckerSumMap{T}(maps::Tuple{LinearMap,LinearMap}) where {T}
227280
A1, A2 = maps
228-
(size(A1, 1) == size(A1, 2) && size(A2, 1) == size(A2, 2)) ||
281+
(_issquare(A1) && _issquare(A2)) ||
229282
throw(ArgumentError("operators need to be square in Kronecker sums"))
230283
for TA in Base.Iterators.map(eltype, maps)
231284
promote_type(T, TA) == T ||
@@ -269,6 +322,66 @@ kronsum(A::MapOrMatrix, B::MapOrMatrix) =
269322
kronsum(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...) =
270323
kronsum(A, kronsum(B, C, Ds...))
271324

325+
"""
326+
sumkronsum(A, B)::LinearCombination
327+
sumkronsum(A, B, Cs...)::LinearCombination
328+
329+
Construct a (lazy) representation of the Kronecker sum `A⊕B` of two or more square
330+
objects of type `LinearMap` or `AbstractMatrix`. This function makes use of the following
331+
representation of Kronecker sums:
332+
333+
```
334+
\\bigoplus_{i=1}^n A_i = \\sum_{i=1}^n I_1 \\otimes \\ldots \\otimes I_{i-1} \\otimes A_i \\otimes I_{i+1} \\otimes \\ldots \\otimes I_n
335+
```
336+
337+
where ```I_k``` is the identity operator of the size of ```A_k```. By associativity, the
338+
Kronecker product of the identity operators may be combined to larger identity operators
339+
```I_{1:i-1}``` and ```I_{i+1:n}```, which yields
340+
341+
```
342+
\\bigoplus_{i=1}^n A_i = \\sum_{i=1}^n I_{1:i-1} \\otimes A_i \\otimes I_{i+1:n}
343+
```
344+
345+
i.e., a `LinearCombination` where each summand is a Kronecker product consisting of three
346+
maps: outer `UniformScalingMap`s and the respective Kronecker factor. This representation is
347+
expected to yield significantly faster multiplication (and reduce memory allocation)
348+
compared to [`kronsum`](@ref), especially for 3 or more Kronecker summands, but
349+
benchmarking intended use cases is highly recommended.
350+
351+
# Examples
352+
```jldoctest; setup=(using LinearAlgebra, SparseArrays, LinearMaps)
353+
julia> J = LinearMap(I, 2) # 2×2 identity map
354+
2×2 LinearMaps.UniformScalingMap{Bool} with scaling factor: true
355+
356+
julia> E = spdiagm(-1 => trues(1)); D = LinearMap(E + E' - 2I);
357+
358+
julia> Δ₁ = kron(D, J) + kron(J, D); # discrete 2D-Laplace operator, Kronecker sum
359+
360+
julia> Δ₂ = sumkronsum(D, D);
361+
362+
julia> Δ₃ = D^⊕(2);
363+
364+
julia> Matrix(Δ₁) == Matrix(Δ₂) == Matrix(Δ₃)
365+
true
366+
```
367+
"""
368+
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix)
369+
LinearAlgebra.checksquare(A, B)
370+
A UniformScalingMap(true, size(B,1)) + UniformScalingMap(true, size(A,1)) B
371+
end
372+
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)
373+
maps = (A, B, C, Ds...)
374+
all(_issquare, maps) || throw(ArgumentError("operators need to be square in Kronecker sums"))
375+
ns = map(a -> size(a, 1), maps)
376+
n = length(maps)
377+
firstmap = first(maps) UniformScalingMap(true, prod(ns[2:end]))
378+
lastmap = UniformScalingMap(true, prod(ns[1:end-1])) last(maps)
379+
middlemaps = sum(enumerate(Base.front(Base.tail(maps)))) do (i, map)
380+
UniformScalingMap(true, prod(ns[1:i])) map UniformScalingMap(true, prod(ns[i+2:end]))
381+
end
382+
return firstmap + middlemaps + lastmap
383+
end
384+
272385
struct KronSumPower{p}
273386
function KronSumPower(p::Integer)
274387
p > 1 || throw(ArgumentError("the Kronecker sum power is only defined for exponents larger than 1, got $k"))
@@ -280,14 +393,24 @@ end
280393
⊕(k::Integer)
281394
282395
Construct a lazy representation of the `k`-th Kronecker sum power `A^⊕(k) = A ⊕ A ⊕ ... ⊕ A`,
283-
where `A` can be a square `AbstractMatrix` or a `LinearMap`.
396+
where `A` can be a square `AbstractMatrix` or a `LinearMap`. This calls [`sumkronsum`](@ref)
397+
on the `k`-tuple `(A, ..., A)`.
398+
399+
# Example
400+
```jldoctest
401+
julia> Matrix([1 0; 0 1]^⊕(2))
402+
4×4 Matrix{Int64}:
403+
2 0 0 0
404+
0 2 0 0
405+
0 0 2 0
406+
0 0 0 2
284407
"""
285408
(k::Integer) = KronSumPower(k)
286409

287410
(a, b, c...) = kronsum(a, b, c...)
288411

289412
Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} =
290-
kronsum(ntuple(n -> convert(LinearMap, A), Val(p))...)
413+
sumkronsum(ntuple(n -> convert(LinearMap, A), Val(p))...)
291414

292415
Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i))
293416
Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2))
@@ -305,10 +428,10 @@ Base.:(==)(A::KroneckerSumMap, B::KroneckerSumMap) =
305428

306429
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerSumMap, x::AbstractVector)
307430
A, B = L.maps
308-
ma, na = size(A)
309-
mb, nb = size(B)
310-
X = reshape(x, (nb, na))
311-
Y = reshape(y, (nb, na))
431+
a = size(A, 1)
432+
b = size(B, 1)
433+
X = reshape(x, (b, a))
434+
Y = reshape(y, (b, a))
312435
_unsafe_mul!(Y, X, transpose(A))
313436
_unsafe_mul!(Y, B, X, true, true)
314437
return y

0 commit comments

Comments
 (0)