Skip to content

Commit e4cc72c

Browse files
authored
Merge branch 'master' into dk/iipfunctionmap
2 parents 0e59e09 + 50f5bf9 commit e4cc72c

File tree

8 files changed

+163
-1
lines changed

8 files changed

+163
-1
lines changed

docs/src/history.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@
77
such as [`Flux.jl`](https://fluxml.ai/Flux.jl/stable/). The reverse differentiation rule
88
makes `A::LinearMap` usable as a static, i.e., non-trainable, layer in a network, and
99
requires the adjoint `A'` of `A` to be defined.
10+
* New map types called `KhatriRaoMap` and `FaceSplittingMap` are introduced. These
11+
correspond to lazy representations of the [column-wise Kronecker product](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product)
12+
and the [row-wise Kronecker product](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Face-splitting_product)
13+
(or "transposed Khatri-Rao product"), respectively. They can be constructed from two
14+
matrices `A` and `B` via `khatrirao(A, B)` and `facesplitting(A, B)`, respectively.
15+
The first is particularly efficient as it makes use of the vec-trick for Kronecker
16+
products and computes `y = khatrirao(A, B) * x` for a vector `x` as
17+
`y = vec(B * Diagonal(x) * transpose(A))`. As such, the Khatri-Rao product can actually
18+
be built for general `LinearMap`s, including function-based types. Even for moderate
19+
sizes of 5 or more columns, this map-vector product is faster than first creating the
20+
explicit Khatri-Rao product in memory and then multiplying with the vector; not to
21+
mention the memory savings. Unfortunately, similar efficiency cannot be achieved for the
22+
face-splitting product.
1023

1124
## What's new in v3.8
1225

docs/src/types.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,18 @@ Type for lazy inverse of another linear map.
118118
InverseMap
119119
```
120120

121+
### `KhatriRaoMap` and `FaceSplittingMap`
122+
123+
Types for lazy [column-wise](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product)
124+
and [row-wise](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Face-splitting_product)
125+
Kronecker product, respectively, also referrerd to
126+
as Khatri-Rao and transposed Khatri-Rao (or face-splitting) product.
127+
128+
```@docs
129+
khatrirao
130+
facesplitting
131+
```
132+
121133
## Methods
122134

123135
### Multiplication methods

src/LinearMaps.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LinearMaps
22

33
export LinearMap, FunctionMap, FillMap, InverseMap
4-
export , squarekron, kronsum, , sumkronsum
4+
export , squarekron, kronsum, , sumkronsum, khatrirao, facesplitting
55

66
using LinearAlgebra
77
using LinearAlgebra: AbstractQ
@@ -345,6 +345,7 @@ include("scaledmap.jl") # multiply by a (real or complex) scalar
345345
include("composition.jl") # composition of linear maps
346346
include("blockmap.jl") # block linear maps
347347
include("kronecker.jl") # Kronecker product of linear maps
348+
include("khatrirao.jl") # Khatri-Rao and face-splitting products
348349
include("fillmap.jl") # linear maps representing constantly filled matrices
349350
include("embeddedmap.jl") # embedded linear maps
350351
include("conversion.jl") # conversion of linear maps to matrices

src/khatrirao.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
struct KhatriRaoMap{T,A<:Tuple{MapOrVecOrMat,MapOrVecOrMat}} <: LinearMap{T}
2+
maps::A
3+
function KhatriRaoMap{T,As}(maps::As) where {T,As<:Tuple{MapOrVecOrMat,MapOrVecOrMat}}
4+
@assert promote_type(T, map(eltype, maps)...) == T "eltype $(eltype(A)) cannot be promoted to $T in KhatriRaoMap constructor"
5+
@inbounds size(maps[1], 2) == size(maps[2], 2) || throw(ArgumentError("matrices need equal number of columns"))
6+
new{T,As}(maps)
7+
end
8+
end
9+
KhatriRaoMap{T}(maps::As) where {T, As} = KhatriRaoMap{T, As}(maps)
10+
11+
"""
12+
khatrirao(A::MapOrVecOrMat, B::MapOrVecOrMat) -> KhatriRaoMap
13+
14+
Construct a lazy representation of the Khatri-Rao (or column-wise Kronecker) product of two
15+
maps or arrays `A` and `B`. For the application to vectors, the tranpose action of `A` on
16+
vectors needs to be defined.
17+
"""
18+
khatrirao(A::MapOrVecOrMat, B::MapOrVecOrMat) =
19+
KhatriRaoMap{Base.promote_op(*, eltype(A), eltype(B))}((A, B))
20+
21+
struct FaceSplittingMap{T,A<:Tuple{AbstractMatrix,AbstractMatrix}} <: LinearMap{T}
22+
maps::A
23+
function FaceSplittingMap{T,As}(maps::As) where {T,As<:Tuple{AbstractMatrix,AbstractMatrix}}
24+
@assert promote_type(T, map(eltype, maps)...) == T "eltype $(eltype(A)) cannot be promoted to $T in KhatriRaoMap constructor"
25+
@inbounds size(maps[1], 1) == size(maps[2], 1) || throw(ArgumentError("matrices need equal number of columns, got $(size(maps[1], 1)) and $(size(maps[2], 1))"))
26+
new{T,As}(maps)
27+
end
28+
end
29+
FaceSplittingMap{T}(maps::As) where {T, As} = FaceSplittingMap{T, As}(maps)
30+
31+
"""
32+
facesplitting(A::AbstractMatrix, B::AbstractMatrix) -> FaceSplittingMap
33+
34+
Construct a lazy representation of the face-splitting (or row-wise Kronecker) product of
35+
two matrices `A` and `B`.
36+
"""
37+
facesplitting(A::AbstractMatrix, B::AbstractMatrix) =
38+
FaceSplittingMap{Base.promote_op(*, eltype(A), eltype(B))}((A, B))
39+
40+
Base.size(K::KhatriRaoMap) = ((A, B) = K.maps; (size(A, 1) * size(B, 1), size(A, 2)))
41+
Base.size(K::FaceSplittingMap) = ((A, B) = K.maps; (size(A, 1), size(A, 2) * size(B, 2)))
42+
Base.adjoint(K::KhatriRaoMap) = facesplitting(map(adjoint, K.maps)...)
43+
Base.adjoint(K::FaceSplittingMap) = khatrirao(map(adjoint, K.maps)...)
44+
Base.transpose(K::KhatriRaoMap) = facesplitting(map(transpose, K.maps)...)
45+
Base.transpose(K::FaceSplittingMap) = khatrirao(map(transpose, K.maps)...)
46+
47+
LinearMaps.MulStyle(::Union{KhatriRaoMap,FaceSplittingMap}) = FiveArg()
48+
49+
function _unsafe_mul!(y, K::KhatriRaoMap, x::AbstractVector)
50+
A, B = K.maps
51+
Y = reshape(y, (size(B, 1), size(A, 1)))
52+
if size(B, 1) <= size(A, 1)
53+
mul!(Y, convert(Matrix, B * Diagonal(x)), transpose(A))
54+
else
55+
mul!(Y, B, transpose(convert(Matrix, A * transpose(Diagonal(x)))))
56+
end
57+
return y
58+
end
59+
function _unsafe_mul!(y, K::KhatriRaoMap, x::AbstractVector, α, β)
60+
A, B = K.maps
61+
Y = reshape(y, (size(B, 1), size(A, 1)))
62+
if size(B, 1) <= size(A, 1)
63+
mul!(Y, convert(Matrix, B * Diagonal(x)), transpose(A), α, β)
64+
else
65+
mul!(Y, B, transpose(convert(Matrix, A * transpose(Diagonal(x)))), α, β)
66+
end
67+
return y
68+
end
69+
70+
function _unsafe_mul!(y, K::FaceSplittingMap, x::AbstractVector)
71+
A, B = K.maps
72+
@inbounds for m in eachindex(y)
73+
y[m] = zero(eltype(y))
74+
l = firstindex(x)
75+
for i in axes(A, 2)
76+
ai = A[m,i]
77+
@simd for k in axes(B, 2)
78+
y[m] += ai*B[m,k]*x[l]
79+
l += 1
80+
end
81+
end
82+
end
83+
return y
84+
end
85+
function _unsafe_mul!(y, K::FaceSplittingMap, x::AbstractVector, α, β)
86+
A, B = K.maps
87+
@inbounds for m in eachindex(y)
88+
y[m] *= β
89+
l = firstindex(x)
90+
for i in axes(A, 2)
91+
ai = A[m,i]
92+
@simd for k in axes(B, 2)
93+
y[m] += ai*B[m,k]*x[l]*α
94+
l += 1
95+
end
96+
end
97+
end
98+
return y
99+
end

src/kronecker.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,12 @@ function _unsafe_mul!(y, L::OuterProductMap, x::AbstractVector)
209209
a, bt = L.maps
210210
mul!(y, a.lmap, bt.lmap * x)
211211
end
212+
function _unsafe_mul!(y, L::KroneckerMap{<:Any,<:Tuple{VectorMap,VectorMap}}, x::AbstractVector)
213+
a, b = L.maps
214+
kron!(y, a.lmap, b.lmap)
215+
rmul!(y, first(x))
216+
return y
217+
end
212218
function _unsafe_mul!(y, L::KroneckerMap2, x::AbstractVector)
213219
require_one_based_indexing(y)
214220
A, B = L.maps

test/khatrirao.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using Test, LinearMaps, LinearAlgebra
2+
3+
@testset "KhatriRaoMap & FaceSplittingMap" begin
4+
for trans in (identity, complex), m in (2, 4)
5+
A = collect(reshape(trans(1:6), 3, 2))
6+
B = collect(reshape(trans(1:2m), m, 2))
7+
K = @inferred khatrirao(A, B)
8+
@test facesplitting(A', B')' === K
9+
M = mapreduce(kron, hcat, eachcol(A), eachcol(B))
10+
Mx = mapreduce((a, b) -> kron(permutedims(a), permutedims(b)), vcat, eachrow(A'), eachrow(B'))
11+
@test size(K) == size(M)
12+
@test size(@inferred adjoint(K)) == reverse(size(K))
13+
@test size(@inferred transpose(K)) == reverse(size(K))
14+
@test Matrix(K) == M
15+
@test Matrix(K') == Mx
16+
@test LinearMaps.MulStyle(K) === LinearMaps.MulStyle(K') === LinearMaps.FiveArg()
17+
@test (K')' === K
18+
@test transpose(transpose(K)) === K
19+
x = trans(rand(-10:10, size(K, 2)))
20+
y = trans(rand(-10:10, size(K, 1)))
21+
for α in (false, true, trans(rand(2:5))), β in (false, true, trans(rand(2:5)))
22+
@test mul!(copy(y), K, x, α, β) == y * β + K * x * α
23+
@test mul!(copy(x), K', y, α, β) == x * β + K' * y * α
24+
end
25+
end
26+
end

test/kronecker.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
3737
@test Matrix(L) M
3838
@test L * v M * v
3939
end
40+
L = ones(3) (b = rand(ComplexF64, 4))
41+
@test L * [2] kron(ones(3), b) * 2
42+
@test Matrix(L) kron(ones(3), b) rtol=2eps(Float64)
4043
L = ones(3) ones(ComplexF64, 4)'
4144
v = rand(4)
4245
@test Matrix(L) == ones(3,4)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,5 @@ include("getindex.jl")
4141
include("inversemap.jl")
4242

4343
include("rrules.jl")
44+
45+
include("khatrirao.jl")

0 commit comments

Comments
 (0)