Skip to content

Introduce KhatriRaoMap and FaceSplittingMap #191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/src/history.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@
such as [`Flux.jl`](https://fluxml.ai/Flux.jl/stable/). The reverse differentiation rule
makes `A::LinearMap` usable as a static, i.e., non-trainable, layer in a network, and
requires the adjoint `A'` of `A` to be defined.
* New map types called `KhatriRaoMap` and `FaceSplittingMap` are introduced. These
correspond to lazy representations of the [column-wise Kronecker product](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product)
and the [row-wise Kronecker product](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Face-splitting_product)
(or "transposed Khatri-Rao product"), respectively. They can be constructed from two
matrices `A` and `B` via `khatrirao(A, B)` and `facesplitting(A, B)`, respectively.
The first is particularly efficient as it makes use of the vec-trick for Kronecker
products and computes `y = khatrirao(A, B) * x` for a vector `x` as
`y = vec(B * Diagonal(x) * transpose(A))`. As such, the Khatri-Rao product can actually
be built for general `LinearMap`s, including function-based types. Even for moderate
sizes of 5 or more columns, this map-vector product is faster than first creating the
explicit Khatri-Rao product in memory and then multiplying with the vector; not to
mention the memory savings. Unfortunately, similar efficiency cannot be achieved for the
face-splitting product.

## What's new in v3.8

Expand Down
12 changes: 12 additions & 0 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,18 @@ Type for lazy inverse of another linear map.
LinearMaps.InverseMap
```

### `KhatriRaoMap` and `FaceSplittingMap`

Types for lazy [column-wise](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product)
and [row-wise](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Face-splitting_product)
Kronecker product, respectively, also referrerd to
as Khatri-Rao and transposed Khatri-Rao (or face-splitting) product.

```@docs
khatrirao
facesplitting
```

## Methods

### Multiplication methods
Expand Down
3 changes: 2 additions & 1 deletion src/LinearMaps.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearMaps

export LinearMap
export ⊗, squarekron, kronsum, ⊕, sumkronsum
export ⊗, squarekron, kronsum, ⊕, sumkronsum, khatrirao, facesplitting
export FillMap
export InverseMap

Expand Down Expand Up @@ -344,6 +344,7 @@ include("composition.jl") # composition of linear maps
include("functionmap.jl") # using a function as linear map
include("blockmap.jl") # block linear maps
include("kronecker.jl") # Kronecker product of linear maps
include("khatrirao.jl") # Khatri-Rao and face-splitting products
include("fillmap.jl") # linear maps representing constantly filled matrices
include("embeddedmap.jl") # embedded linear maps
include("conversion.jl") # conversion of linear maps to matrices
Expand Down
99 changes: 99 additions & 0 deletions src/khatrirao.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
struct KhatriRaoMap{T,A<:Tuple{MapOrVecOrMat,MapOrVecOrMat}} <: LinearMap{T}
maps::A
function KhatriRaoMap{T,As}(maps::As) where {T,As<:Tuple{MapOrVecOrMat,MapOrVecOrMat}}
@assert promote_type(T, map(eltype, maps)...) == T "eltype $(eltype(A)) cannot be promoted to $T in KhatriRaoMap constructor"
@inbounds size(maps[1], 2) == size(maps[2], 2) || throw(ArgumentError("matrices need equal number of columns"))
new{T,As}(maps)
end
end
KhatriRaoMap{T}(maps::As) where {T, As} = KhatriRaoMap{T, As}(maps)

"""
khatrirao(A::MapOrVecOrMat, B::MapOrVecOrMat) -> KhatriRaoMap

Construct a lazy representation of the Khatri-Rao (or column-wise Kronecker) product of two
maps or arrays `A` and `B`. For the application to vectors, the tranpose action of `A` on
vectors needs to be defined.
"""
khatrirao(A::MapOrVecOrMat, B::MapOrVecOrMat) =
KhatriRaoMap{Base.promote_op(*, eltype(A), eltype(B))}((A, B))

struct FaceSplittingMap{T,A<:Tuple{AbstractMatrix,AbstractMatrix}} <: LinearMap{T}
maps::A
function FaceSplittingMap{T,As}(maps::As) where {T,As<:Tuple{AbstractMatrix,AbstractMatrix}}
@assert promote_type(T, map(eltype, maps)...) == T "eltype $(eltype(A)) cannot be promoted to $T in KhatriRaoMap constructor"
@inbounds size(maps[1], 1) == size(maps[2], 1) || throw(ArgumentError("matrices need equal number of columns, got $(size(maps[1], 1)) and $(size(maps[2], 1))"))
new{T,As}(maps)
end
end
FaceSplittingMap{T}(maps::As) where {T, As} = FaceSplittingMap{T, As}(maps)

"""
facesplitting(A::AbstractMatrix, B::AbstractMatrix) -> FaceSplittingMap

Construct a lazy representation of the face-splitting (or row-wise Kronecker) product of
two matrices `A` and `B`.
"""
facesplitting(A::AbstractMatrix, B::AbstractMatrix) =
FaceSplittingMap{Base.promote_op(*, eltype(A), eltype(B))}((A, B))

Base.size(K::KhatriRaoMap) = ((A, B) = K.maps; (size(A, 1) * size(B, 1), size(A, 2)))
Base.size(K::FaceSplittingMap) = ((A, B) = K.maps; (size(A, 1), size(A, 2) * size(B, 2)))
Base.adjoint(K::KhatriRaoMap) = facesplitting(map(adjoint, K.maps)...)
Base.adjoint(K::FaceSplittingMap) = khatrirao(map(adjoint, K.maps)...)
Base.transpose(K::KhatriRaoMap) = facesplitting(map(transpose, K.maps)...)
Base.transpose(K::FaceSplittingMap) = khatrirao(map(transpose, K.maps)...)

LinearMaps.MulStyle(::Union{KhatriRaoMap,FaceSplittingMap}) = FiveArg()

function _unsafe_mul!(y, K::KhatriRaoMap, x::AbstractVector)
A, B = K.maps
Y = reshape(y, (size(B, 1), size(A, 1)))
if size(B, 1) <= size(A, 1)
mul!(Y, convert(Matrix, B * Diagonal(x)), transpose(A))
else
mul!(Y, B, transpose(convert(Matrix, A * transpose(Diagonal(x)))))
end
return y
end
function _unsafe_mul!(y, K::KhatriRaoMap, x::AbstractVector, α, β)
A, B = K.maps
Y = reshape(y, (size(B, 1), size(A, 1)))
if size(B, 1) <= size(A, 1)
mul!(Y, convert(Matrix, B * Diagonal(x)), transpose(A), α, β)
else
mul!(Y, B, transpose(convert(Matrix, A * transpose(Diagonal(x)))), α, β)
end
return y
end

function _unsafe_mul!(y, K::FaceSplittingMap, x::AbstractVector)
A, B = K.maps
@inbounds for m in eachindex(y)
y[m] = zero(eltype(y))
l = firstindex(x)
for i in axes(A, 2)
ai = A[m,i]
@simd for k in axes(B, 2)
y[m] += ai*B[m,k]*x[l]
l += 1
end
end
end
return y
end
function _unsafe_mul!(y, K::FaceSplittingMap, x::AbstractVector, α, β)
A, B = K.maps
@inbounds for m in eachindex(y)
y[m] *= β
l = firstindex(x)
for i in axes(A, 2)
ai = A[m,i]
@simd for k in axes(B, 2)
y[m] += ai*B[m,k]*x[l]*α
l += 1
end
end
end
return y
end
6 changes: 6 additions & 0 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ function _unsafe_mul!(y, L::OuterProductMap, x::AbstractVector)
a, bt = L.maps
mul!(y, a.lmap, bt.lmap * x)
end
function _unsafe_mul!(y, L::KroneckerMap{<:Any,<:Tuple{VectorMap,VectorMap}}, x::AbstractVector)
a, b = L.maps
kron!(y, a.lmap, b.lmap)
rmul!(y, first(x))
return y
end
function _unsafe_mul!(y, L::KroneckerMap2, x::AbstractVector)
require_one_based_indexing(y)
A, B = L.maps
Expand Down
26 changes: 26 additions & 0 deletions test/khatrirao.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using Test, LinearMaps, LinearAlgebra

@testset "KhatriRaoMap & FaceSplittingMap" begin
for trans in (identity, complex), m in (2, 4)
A = collect(reshape(trans(1:6), 3, 2))
B = collect(reshape(trans(1:2m), m, 2))
K = @inferred khatrirao(A, B)
@test facesplitting(A', B')' === K
M = mapreduce(kron, hcat, eachcol(A), eachcol(B))
Mx = mapreduce((a, b) -> kron(permutedims(a), permutedims(b)), vcat, eachrow(A'), eachrow(B'))
@test size(K) == size(M)
@test size(@inferred adjoint(K)) == reverse(size(K))
@test size(@inferred transpose(K)) == reverse(size(K))
@test Matrix(K) == M
@test Matrix(K') == Mx
@test LinearMaps.MulStyle(K) === LinearMaps.MulStyle(K') === LinearMaps.FiveArg()
@test (K')' === K
@test transpose(transpose(K)) === K
x = trans(rand(-10:10, size(K, 2)))
y = trans(rand(-10:10, size(K, 1)))
for α in (false, true, trans(rand(2:5))), β in (false, true, trans(rand(2:5)))
@test mul!(copy(y), K, x, α, β) == y * β + K * x * α
@test mul!(copy(x), K', y, α, β) == x * β + K' * y * α
end
end
end
3 changes: 3 additions & 0 deletions test/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
@test Matrix(L) ≈ M
@test L * v ≈ M * v
end
L = ones(3) ⊗ (b = rand(ComplexF64, 4))
@test L * [2] ≈ kron(ones(3), b) * 2
@test Matrix(L) ≈ kron(ones(3), b) rtol=2eps(Float64)
L = ones(3) ⊗ ones(ComplexF64, 4)'
v = rand(4)
@test Matrix(L) == ones(3,4)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,5 @@ include("getindex.jl")
include("inversemap.jl")

include("rrules.jl")

include("khatrirao.jl")