Skip to content

Commit e39cbc1

Browse files
authored
Square Kronecker products and sums (#183)
1 parent 10b6901 commit e39cbc1

File tree

6 files changed

+233
-74
lines changed

6 files changed

+233
-74
lines changed

docs/src/history.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
argument to the constructor (see the docstring for details). Note that `A` must be
99
compatible with the solver: `A` can, for example, be a factorization, or another
1010
`LinearMap` in combination with an iterative solver.
11+
* New constructors for lazy representations of Kronecker products ([`squarekron`](@ref))
12+
and sums ([`sumkronsum`](@ref)) for _square_ factors and summands, respectively, are
13+
introduced. They target cases with 3 or more factors/summands, and benchmarking intended
14+
use cases for comparison with `KroneckerMap` (constructed via [`Base.kron`](@ref)) and
15+
`KroneckerSumMap` (constructed via [`kronsum`](@ref)) is recommended.
1116

1217
## What's new in v3.7
1318

docs/src/types.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ kronsum
7272
LinearMaps.:⊕
7373
```
7474

75+
There exist alternative constructors of Kronecker products and sums for square factors and
76+
summands, respectively. These are designed for cases of 3 or more arguments, and
77+
benchmarking intended use cases for comparison with `KroneckerMap` and `KroneckerSumMap`
78+
is recommended.
79+
80+
```@docs
81+
squarekron
82+
sumkronsum
83+
```
84+
7585
### `BlockMap` and `BlockDiagonalMap`
7686

7787
Types for representing block (diagonal) maps lazily.

src/LinearMaps.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LinearMaps
22

33
export LinearMap
4-
export , kronsum,
4+
export , squarekron, kronsum, , sumkronsum
55
export FillMap
66
export InverseMap
77

@@ -78,6 +78,8 @@ function check_dim_mul(C, A, B)
7878
return nothing
7979
end
8080

81+
_issquare(A) = size(A, 1) == size(A, 2)
82+
8183
_front(As::Tuple) = Base.front(As)
8284
_front(As::AbstractVector) = @inbounds @views As[begin:end-1]
8385
_tail(As::Tuple) = Base.tail(As)
@@ -331,7 +333,7 @@ end
331333

332334
include("transpose.jl") # transposing linear maps
333335
include("wrappedmap.jl") # wrap a matrix of linear map in a new type, thereby allowing to alter its properties
334-
include("left.jl") # left multiplication by a transpose or adjoint vector
336+
include("left.jl") # left multiplication by a matrix/transpose or adjoint vector
335337
include("uniformscalingmap.jl") # the uniform scaling map, to be able to make linear combinations of LinearMap objects and multiples of I
336338
include("linearcombination.jl") # defining linear combinations of linear maps
337339
include("scaledmap.jl") # multiply by a (real or complex) scalar

src/kronecker.jl

Lines changed: 173 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,45 @@ for k in 3:8 # is 8 sufficient?
7979
kron($(mapargs...), $(Symbol(:A, k)), convert_to_lmaps(As...)...)
8080
end
8181

82+
@doc raw"""
83+
squarekron(A₁::MapOrMatrix, A₂::MapOrMatrix, A₃::MapOrMatrix, Aᵢ::MapOrMatrix...)::CompositeMap
84+
85+
Construct a (lazy) representation of the Kronecker product `⨂ᵢ₌₁ⁿ Aᵢ` of at least 3 _square_
86+
Kronecker factors. In contrast to [`kron`](@ref), this function assumes that all Kronecker
87+
factors are square, and makes use of the following identity[^1]:
88+
89+
```math
90+
\bigotimes_{i=1}^n A_i = \prod_{i=1}^n I_1 \otimes \ldots \otimes I_{i-1} \otimes A_i \otimes I_{i+1} \otimes \ldots \otimes I_n
91+
```
92+
93+
where ``I_k`` is an identity matrix of the size of ``A_k``. By associativity, the
94+
Kronecker product of the identity operators may be combined to larger identity operators
95+
``I_{1:i-1}`` and ``I_{i+1:n}``, which yields
96+
97+
```math
98+
\bigotimes_{i=1}^n A_i = \prod_{i=1}^n I_{1:i-1} \otimes A_i \otimes I_{i+1:n}
99+
```
100+
101+
i.e., a `CompositeMap` where each factor is a Kronecker product consisting of three maps:
102+
outer `UniformScalingMap`s and the respective Kronecker factor. This representation is
103+
expected to yield significantly faster multiplication (and reduce memory allocation)
104+
compared to [`kron`](@ref), but benchmarking intended use cases is highly recommended.
105+
106+
[^1]: Fernandes, P. and Plateau, B. and Stewart, W. J. ["Efficient Descriptor-Vector Multiplications in Stochastic Automata Networks"](https://doi.org/10.1145/278298.278303), _Journal of the ACM_, 45(3), 381–414, 1998.
107+
"""
108+
function squarekron(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)
109+
maps = (A, B, C, Ds...)
110+
T = promote_type(map(eltype, maps)...)
111+
all(_issquare, maps) || throw(ArgumentError("operators need to be square in Kronecker sums"))
112+
ns = map(a -> size(a, 1), maps)
113+
firstmap = first(maps) UniformScalingMap(true, prod(ns[2:end]))
114+
lastmap = UniformScalingMap(true, prod(ns[1:end-1])) last(maps)
115+
middlemaps = prod(enumerate(maps[2:end-1])) do (i, map)
116+
UniformScalingMap(true, prod(ns[1:i])) map UniformScalingMap(true, prod(ns[i+2:end]))
117+
end
118+
return firstmap * middlemaps * lastmap
119+
end
120+
82121
struct KronPower{p}
83122
function KronPower(p::Integer)
84123
p > 1 || throw(ArgumentError("the Kronecker power is only defined for exponents larger than 1, got $k"))
@@ -114,74 +153,90 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) =
114153
# multiplication helper functions
115154
#################
116155

117-
@inline function _kronmul!(y, B, x, A, T)
118-
ma, na = size(A)
119-
mb, nb = size(B)
120-
X = reshape(x, (nb, na))
121-
Y = reshape(y, (mb, ma))
122-
if B isa UniformScalingMap
123-
_unsafe_mul!(Y, X, transpose(A))
124-
lmul!(B.λ, y)
156+
@inline function _kronmul!(Y, B, X, A)
157+
# minimize intermediate memory allocation
158+
if size(B, 2) * size(A, 1) <= size(B, 1) * size(A, 2)
159+
temp = similar(Y, (size(B, 2), size(A, 1) ))
160+
_unsafe_mul!(temp, X, transpose(A))
161+
_unsafe_mul!(Y, B, temp)
125162
else
126-
temp = similar(Y, (ma, nb))
127-
_unsafe_mul!(temp, A, copy(transpose(X)))
128-
_unsafe_mul!(Y, B, transpose(temp))
163+
temp = similar(Y, (size(B, 1), size(A, 2)))
164+
_unsafe_mul!(temp, B, X)
165+
_unsafe_mul!(Y, temp, transpose(A))
129166
end
130-
return y
167+
return Y
131168
end
132-
@inline function _kronmul!(y, B, x, A::UniformScalingMap, _)
133-
ma, na = size(A)
134-
mb, nb = size(B)
135-
iszero(A.λ) && return fill!(y, zero(eltype(y)))
136-
X = reshape(x, (nb, na))
137-
Y = reshape(y, (mb, ma))
169+
@inline function _kronmul!(Y, B::UniformScalingMap, X, A)
170+
_unsafe_mul!(Y, X, transpose(A))
171+
!isone(B.λ) && lmul!(B.λ, Y)
172+
return Y
173+
end
174+
@inline function _kronmul!(Y, B, X, A::UniformScalingMap)
138175
_unsafe_mul!(Y, B, X)
139-
!isone(A.λ) && rmul!(y, A.λ)
140-
return y
176+
!isone(A.λ) && rmul!(Y, A.λ)
177+
return Y
141178
end
142-
@inline function _kronmul!(y, B, x, A::VecOrMatMap, _)
143-
ma, na = size(A)
144-
mb, nb = size(B)
145-
X = reshape(x, (nb, na))
146-
Y = reshape(y, (mb, ma))
179+
# disambiguation (cannot occur)
180+
@inline function _kronmul!(Y, B::UniformScalingMap, X, A::UniformScalingMap)
181+
mul!(parent(Y), A.λ * B.λ, parent(X))
182+
return Y
183+
end
184+
@inline function _kronmul!(Y, B, X, A::VecOrMatMap)
147185
At = transpose(A.lmap)
148-
if B isa UniformScalingMap
149-
# the following is (perhaps due to the reshape?) faster than
150-
# _unsafe_mul!(Y, B * X, At)
151-
_unsafe_mul!(Y, X, At)
152-
lmul!(B.λ, y)
153-
elseif nb*ma <= mb*na
186+
if size(B, 2) * size(A, 1) <= size(B, 1) * size(A, 2)
154187
_unsafe_mul!(Y, B, X * At)
155188
else
156189
_unsafe_mul!(Y, Matrix(B * X), At)
157190
end
158-
return y
191+
return Y
192+
end
193+
@inline function _kronmul!(Y, B::UniformScalingMap, X, A::VecOrMatMap)
194+
_unsafe_mul!(Y, X, transpose(A.lmap))
195+
!isone(B.λ) && lmul!(B.λ, Y)
196+
return Y
159197
end
198+
160199
const VectorMap{T} = WrappedMap{T,<:AbstractVector}
161200
const AdjOrTransVectorMap{T} = WrappedMap{T,<:LinearAlgebra.AdjOrTransAbsVec}
162-
@inline _kronmul!(y, B::AdjOrTransVectorMap, x, a::VectorMap, _) = mul!(y, a.lmap, B.lmap * x)
163201

164202
#################
165203
# multiplication with vectors
166204
#################
167205

168206
const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}}
169-
207+
const OuterProductMap{T} = KroneckerMap{T, <:Tuple{VectorMap, AdjOrTransVectorMap}}
208+
function _unsafe_mul!(y, L::OuterProductMap, x::AbstractVector)
209+
a, bt = L.maps
210+
mul!(y, a.lmap, bt.lmap * x)
211+
end
170212
function _unsafe_mul!(y, L::KroneckerMap2, x::AbstractVector)
171213
require_one_based_indexing(y)
172214
A, B = L.maps
173-
_kronmul!(y, B, x, A, eltype(L))
215+
ma, na = size(A)
216+
mb, nb = size(B)
217+
X = reshape(x, (nb, na))
218+
Y = reshape(y, (mb, ma))
219+
_kronmul!(Y, B, X, A)
174220
return y
175221
end
176222
function _unsafe_mul!(y, L::KroneckerMap, x::AbstractVector)
177223
require_one_based_indexing(y)
178224
maps = L.maps
179225
if length(maps) == 2 # reachable only for L.maps::Vector
180-
@inbounds _kronmul!(y, maps[2], x, maps[1], eltype(L))
226+
A, B = maps
227+
ma, na = size(A)
228+
mb, nb = size(B)
229+
X = reshape(x, (nb, na))
230+
Y = reshape(y, (mb, ma))
231+
_kronmul!(Y, B, X, A)
181232
else
182233
A = first(maps)
183234
B = KroneckerMap{eltype(L)}(_tail(maps))
184-
_kronmul!(y, B, x, A, eltype(L))
235+
ma, na = size(A)
236+
mb, nb = size(B)
237+
X = reshape(x, (nb, na))
238+
Y = reshape(y, (mb, ma))
239+
_kronmul!(Y, B, X, A)
185240
end
186241
return y
187242
end
@@ -225,7 +280,7 @@ struct KroneckerSumMap{T, As<:Tuple{LinearMap, LinearMap}} <: LinearMap{T}
225280
maps::As
226281
function KroneckerSumMap{T}(maps::Tuple{LinearMap,LinearMap}) where {T}
227282
A1, A2 = maps
228-
(size(A1, 1) == size(A1, 2) && size(A2, 1) == size(A2, 2)) ||
283+
(_issquare(A1) && _issquare(A2)) ||
229284
throw(ArgumentError("operators need to be square in Kronecker sums"))
230285
for TA in Base.Iterators.map(eltype, maps)
231286
promote_type(T, TA) == T ||
@@ -269,6 +324,68 @@ kronsum(A::MapOrMatrix, B::MapOrMatrix) =
269324
kronsum(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...) =
270325
kronsum(A, kronsum(B, C, Ds...))
271326

327+
@doc raw"""
328+
sumkronsum(A, B)::LinearCombination
329+
sumkronsum(A, B, Cs...)::LinearCombination
330+
331+
Construct a (lazy) representation of the Kronecker sum `A⊕B` of two or more square
332+
objects of type `LinearMap` or `AbstractMatrix`. This function makes use of the following
333+
representation of Kronecker sums[^1]:
334+
335+
```math
336+
\bigoplus_{i=1}^n A_i = \sum_{i=1}^n I_1 \otimes \ldots \otimes I_{i-1} \otimes A_i \otimes I_{i+1} \otimes \ldots \otimes I_n
337+
```
338+
339+
where ``I_k`` is the identity operator of the size of ``A_k``. By associativity, the
340+
Kronecker product of the identity operators may be combined to larger identity operators
341+
``I_{1:i-1}`` and ``I_{i+1:n}``, which yields
342+
343+
```math
344+
\bigoplus_{i=1}^n A_i = \sum_{i=1}^n I_{1:i-1} \otimes A_i \otimes I_{i+1:n},
345+
```
346+
347+
i.e., a `LinearCombination` where each summand is a Kronecker product consisting of three
348+
maps: outer `UniformScalingMap`s and the respective Kronecker factor. This representation is
349+
expected to yield significantly faster multiplication (and reduce memory allocation)
350+
compared to [`kronsum`](@ref), especially for 3 or more Kronecker summands, but
351+
benchmarking intended use cases is highly recommended.
352+
353+
# Examples
354+
```jldoctest; setup=(using LinearAlgebra, SparseArrays, LinearMaps)
355+
julia> J = LinearMap(I, 2) # 2×2 identity map
356+
2×2 LinearMaps.UniformScalingMap{Bool} with scaling factor: true
357+
358+
julia> E = spdiagm(-1 => trues(1)); D = LinearMap(E + E' - 2I);
359+
360+
julia> Δ₁ = kron(D, J) + kron(J, D); # discrete 2D-Laplace operator, Kronecker sum
361+
362+
julia> Δ₂ = sumkronsum(D, D);
363+
364+
julia> Δ₃ = D^⊕(2);
365+
366+
julia> Matrix(Δ₁) == Matrix(Δ₂) == Matrix(Δ₃)
367+
true
368+
```
369+
370+
[^1]: Fernandes, P. and Plateau, B. and Stewart, W. J. ["Efficient Descriptor-Vector Multiplications in Stochastic Automata Networks"](https://doi.org/10.1145/278298.278303), _Journal of the ACM_, 45(3), 381–414, 1998.
371+
"""
372+
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix)
373+
LinearAlgebra.checksquare(A, B)
374+
A UniformScalingMap(true, size(B,1)) + UniformScalingMap(true, size(A,1)) B
375+
end
376+
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)
377+
maps = (A, B, C, Ds...)
378+
all(_issquare, maps) || throw(ArgumentError("operators need to be square in Kronecker sums"))
379+
ns = map(a -> size(a, 1), maps)
380+
n = length(maps)
381+
firstmap = first(maps) UniformScalingMap(true, prod(ns[2:end]))
382+
lastmap = UniformScalingMap(true, prod(ns[1:end-1])) last(maps)
383+
middlemaps = sum(enumerate(Base.front(Base.tail(maps)))) do (i, map)
384+
UniformScalingMap(true, prod(ns[1:i])) map UniformScalingMap(true, prod(ns[i+2:end]))
385+
end
386+
return firstmap + middlemaps + lastmap
387+
end
388+
272389
struct KronSumPower{p}
273390
function KronSumPower(p::Integer)
274391
p > 1 || throw(ArgumentError("the Kronecker sum power is only defined for exponents larger than 1, got $k"))
@@ -280,14 +397,26 @@ end
280397
⊕(k::Integer)
281398
282399
Construct a lazy representation of the `k`-th Kronecker sum power `A^⊕(k) = A ⊕ A ⊕ ... ⊕ A`,
283-
where `A` can be a square `AbstractMatrix` or a `LinearMap`.
400+
where `A` can be a square `AbstractMatrix` or a `LinearMap`. This calls [`sumkronsum`](@ref)
401+
on the `k`-tuple `(A, ..., A)` for `k ≥ 3`.
402+
403+
# Example
404+
```jldoctest
405+
julia> Matrix([1 0; 0 1]^⊕(2))
406+
4×4 Matrix{Int64}:
407+
2 0 0 0
408+
0 2 0 0
409+
0 0 2 0
410+
0 0 0 2
284411
"""
285412
(k::Integer) = KronSumPower(k)
286413

287414
(a, b, c...) = kronsum(a, b, c...)
288415

416+
Base.:(^)(A::MapOrMatrix, ::KronSumPower{2}) =
417+
kronsum(convert(LinearMap, A), convert(LinearMap, A))
289418
Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} =
290-
kronsum(ntuple(n -> convert(LinearMap, A), Val(p))...)
419+
sumkronsum(ntuple(_ -> convert(LinearMap, A), Val(p))...)
291420

292421
Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i))
293422
Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2))
@@ -305,10 +434,10 @@ Base.:(==)(A::KroneckerSumMap, B::KroneckerSumMap) =
305434

306435
function _unsafe_mul!(y, L::KroneckerSumMap, x::AbstractVector)
307436
A, B = L.maps
308-
ma, na = size(A)
309-
mb, nb = size(B)
310-
X = reshape(x, (nb, na))
311-
Y = reshape(y, (nb, na))
437+
a = size(A, 1)
438+
b = size(B, 1)
439+
X = reshape(x, (b, a))
440+
Y = reshape(y, (b, a))
312441
_unsafe_mul!(Y, X, transpose(A))
313442
_unsafe_mul!(Y, B, X, true, true)
314443
return y

src/left.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function _unsafe_mul!(X, Y::TransposeAbsVecOrMat, A::LinearMap)
7272
return X
7373
end
7474
# unwrap WrappedMaps
75-
_unsafe_mul!(X, Y::AbstractMatrix, A::WrappedMap) = mul!(X, Y, A.lmap)
75+
_unsafe_mul!(X, Y::AbstractMatrix, A::WrappedMap) = _unsafe_mul!(X, Y, A.lmap)
7676
# disambiguation
7777
_unsafe_mul!(X, Y::TransposeAbsVecOrMat, A::WrappedMap) = _unsafe_mul!(X, Y, A.lmap)
7878

0 commit comments

Comments
 (0)