Skip to content

Preserve bandwidth in ConcreteMultiplication to BandedMatrix conversion #358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 30 additions & 32 deletions src/Spaces/PolynomialSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,43 +151,56 @@ See https://github.com/JuliaLinearAlgebra/BandedMatrices.jl/blob/master/LICENSE
end

_view(::Any, A, b) = view(A, b)
_view(::Val{true}, A::BandedMatrix, b) = dataview(view(A, b))
function _view(::Val{true}, A::BandedMatrix, b::Band)
l, u = bandwidths(A)
-l <= b.i <= u || throw(ArgumentError("invalid band $b for bandwidths $((-l,u))"))
dataview(view(A, b))
end

function _get_bands(B, C, bmk, f, ValBC)
function _get_bands(B, C, bmk, f, valB)
Cbmk = _view(Val(true), C, band(bmk*f))
Bm = _view(Val(true), B, band(flipsign(bmk-1, f)))
B0 = _view(Val(true), B, band(flipsign(bmk, f)))
Bp = _view(ValBC, B, band(flipsign(bmk+1, f)))
Bp = _view(valB, B, band(flipsign(bmk+1, f)))
Cbmk, Bm, B0, Bp
end

function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)
Jp = _view(ValJ, J, band(1))
J0 = _view(ValJ, J, band(0))
Jm = _view(ValJ, J, band(-1))
# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is
# specified by b, not by the parameters in B
function jac_gbmm!(α, J, B, β, C, b, valB)
if β ≠ 1
lmul!(β,C)
end

n = size(J,1)
Cn, Cm = size(C)

Jp = _view(Val(true), J, band(1))
J0 = _view(Val(true), J, band(0))
Jm = _view(Val(true), J, band(-1))

kr = intersect(-1:b-1, b-Cm+1:b-1+Cn)

# unwrap the loops to forward indexing to the data wherever applicable
# this might also help with cache localization
k = -1
if k in kr
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, ValBC)
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB)
for i in 1:n-b+k
Cbmk[i] += α * Bm[i+1] * Jp[i]
end
end

k = 0
if k in kr
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true))
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB)
for i in 1:n-b+k
Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i])
end
end

for k in max(1, first(kr)):last(kr)
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true))
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB)
Cbmk[1] += α * (Bm[2] * Jp[1] + B0[1] * J0[1])
for i in 2:n-b+k
Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i] + Bp[i-1] * Jm[i-1])
Expand All @@ -198,15 +211,15 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)

k = -1
if k in kr
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, ValBC)
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, valB)
for (i, Ji) in enumerate(b-k:n-1)
Ckmb[i] += α * Bp[i] * Jm[Ji]
end
end

k = 0
if k in kr
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, Val(true))
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, valB)
Ckmb[1] += α * Bp[1] * Jm[b-k]
for (i, Ji) in enumerate(b-k+1:n-1)
Ckmb[i] += α * B0[i] * J0[Ji]
Expand Down Expand Up @@ -238,21 +251,6 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)
return C
end

# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is
# specified by b, not by the parameters in B
function jac_gbmm!(α, J, B, β, C, b, valJ, valBC)
if β ≠ 1
lmul!(β,C)
end

n = size(J,1)
Cn, Cm = size(C)

_jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, valJ, valBC)

C
end

function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T},
NTuple{2,UnitRange{Int}}}) where {PS<:PolynomialSpace,T,C<:PolynomialSpace}
M=parent(S)
Expand Down Expand Up @@ -285,31 +283,31 @@ function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T},

#Multiplication is transpose
J=Operator{T}(Recurrence(M.space))[jkr,jkr]
valJ = all(>=(1), bandwidths(J)) ? Val(true) : Val(false)

B=n-1 # final bandwidth

# Clenshaw for operators
Bk2 = BandedMatrix(Zeros{T}(size(J)), (B,B))
dataview(view(Bk2, band(0))) .= a[n]/recβ(T,sp,n-1)
α,β = recα(T,sp,n-1),recβ(T,sp,n-2)
Bk1 = (-α/β)*Bk2
Bk1 = lmul!(-α/β, copy(Bk2))
dataview(view(Bk1, band(0))) .+= a[n-1]/β
jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0,valJ, Val(true))
jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0, Val(true))
b=1 # we keep track of bandwidths manually to reuse memory
for k=n-2:-1:2
# b goes from 1:
α,β,γ=recα(T,sp,k),recβ(T,sp,k-1),recγ(T,sp,k+1)
lmul!(-γ/β,Bk2)
dataview(view(Bk2, band(0))) .+= a[k]/β
jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b,valJ,Val(true))
jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b,Val(true))
LinearAlgebra.axpy!(-α/β,Bk1,Bk2)
Bk2,Bk1=Bk1,Bk2
b+=1
end
α,γ=recα(T,sp,1),recγ(T,sp,2)
lmul!(-γ,Bk2)
dataview(view(Bk2, band(0))) .+= a[1]
jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b,valJ,Val(false))
jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b,Val(false))
LinearAlgebra.axpy!(-α,Bk1,Bk2)

# relationship between jkr and kr, jr
Expand Down
Loading