Skip to content

Commit b6f73f1

Browse files
authored
Use dataview in jac_gbmm (#197)
* use dataview in jac_gbmm * use inbands view where provable * copy over dataview from BandedMatrices * fix band * unwrap loops * change broadcast to loop in band0
1 parent 2de9a10 commit b6f73f1

File tree

2 files changed

+117
-72
lines changed

2 files changed

+117
-72
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunOrthogonalPolynomials"
22
uuid = "b70543e2-c0d9-56b8-a290-0d4d6d4de211"
3-
version = "0.6.12"
3+
version = "0.6.13"
44

55
[deps]
66
ApproxFunBase = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"

src/Spaces/PolynomialSpace.jl

Lines changed: 116 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -128,80 +128,124 @@ end
128128

129129
getindex(M::ConcreteMultiplication{C,PS,T},k::Integer,j::Integer) where {PS<:PolynomialSpace,T,C<:PolynomialSpace} = M[k:k,j:j][1,1]
130130

131+
if view(brand(0,0,0,0), band(0)) isa BandedMatrices.BandedMatrixBand
132+
dataview(V) = BandedMatrices.dataview(V)
133+
else
134+
#=
135+
dataview is broken on BandedMatrices v0.17.6 and older.
136+
We copy the function over from BandedMatrices.jl, which is distributed under the MIT license
137+
See https://github.com/JuliaLinearAlgebra/BandedMatrices.jl/blob/master/LICENSE
138+
=#
139+
function dataview(V)
140+
A = parent(parent(V))
141+
b = first(parentindices(V)).band.i
142+
m,n = size(A)
143+
l,u = bandwidths(A)
144+
data = BandedMatrices.bandeddata(A)
145+
view(data, u - b + 1, max(b,0)+1:min(n,m+b))
146+
end
147+
end
131148

149+
_view(::Any, A, b) = view(A, b)
150+
_view(::Val{true}, A::BandedMatrix, b) = dataview(view(A, b))
132151

152+
function _get_bands(B, C, bmk, f, ValBC)
153+
Cbmk = _view(Val(true), C, band(bmk*f))
154+
Bm = _view(Val(true), B, band(flipsign(bmk-1, f)))
155+
B0 = _view(Val(true), B, band(flipsign(bmk, f)))
156+
Bp = _view(ValBC, B, band(flipsign(bmk+1, f)))
157+
Cbmk, Bm, B0, Bp
158+
end
133159

134-
# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is
135-
# specified by b, not by the parameters in B
136-
function jac_gbmm!(α, J, B, β, C, b)
137-
if β 1
138-
lmul!(β,C)
139-
end
160+
function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)
161+
Jp = _view(ValJ, J, band(1))
162+
J0 = _view(ValJ, J, band(0))
163+
Jm = _view(ValJ, J, band(-1))
140164

141-
Jp = view(J, band(1))
142-
J0 = view(J, band(0))
143-
Jm = view(J, band(-1))
144-
n = size(J,1)
165+
kr = intersect(-1:b-1, b-Cm+1:b-1+Cn)
145166

146-
Cn, Cm = size(C)
167+
# unwrap the loops to forward indexing to the data wherever applicable
168+
# this might also help with cache localization
169+
k = -1
170+
if k in kr
171+
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, ValBC)
172+
for i in 1:n-b+k
173+
Cbmk[i] += α * Bm[i+1] * Jp[i]
174+
end
175+
end
147176

148-
@views for k=-1:b-1
149-
if 1-Cn b-k Cm-1 # if inbands
150-
Cbmk = C[band(b-k)]
151-
Bm = B[band(b-k-1)]
152-
B0 = B[band(b-k)]
153-
Bp = B[band(b-k+1)]
154-
for i in 1:n-b+k
155-
Cbmk[i] += α * Bm[i+1] * Jp[i]
156-
end
157-
if k 0
158-
for i in 1:n-b+k
159-
Cbmk[i] += α * B0[i] * J0[i]
160-
end
161-
if k 1
162-
for i in 1:n-1-b+k
163-
Cbmk[i+1] += α * Bp[i] * Jm[i]
164-
end
165-
end
166-
end
177+
k = 0
178+
if k in kr
179+
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true))
180+
for i in 1:n-b+k
181+
Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i])
167182
end
168183
end
169184

170-
@views for k=-1:b-1
171-
if 1-Cn k-b Cm-1 # if inbands
172-
Ckmb = C[band(k-b)]
173-
Bp = B[band(k-b+1)]
174-
B0 = B[band(k-b)]
175-
Bm = B[band(k-b-1)]
176-
for (i, Ji) in enumerate(b-k:n-1)
177-
Ckmb[i] += α * Bp[i] * Jm[Ji]
178-
end
179-
if k 0
180-
for (i, Ji) in enumerate(b-k+1:n)
181-
Ckmb[i] += α * B0[i] * J0[Ji]
182-
end
183-
if k 1
184-
for (i, Ji) in enumerate(b-k+1:n-1)
185-
Ckmb[i] += α * Bm[i] * Jp[Ji]
186-
end
187-
end
188-
end
185+
for k in max(1, first(kr)):last(kr)
186+
Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true))
187+
Cbmk[1] += α * (Bm[2] * Jp[1] + B0[1] * J0[1])
188+
for i in 2:n-b+k
189+
Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i] + Bp[i-1] * Jm[i-1])
189190
end
190191
end
191192

192-
@views begin
193-
C0 = C[band(0)]
194-
Bm = B[band(-1)]
195-
Bp = B[band(1)]
196-
C0 .+= α.*B[band(0)].*J0
197-
for i in 1:n-1
198-
C0[i] += α * Bm[i] * Jp[i]
193+
kr = intersect(-1:b-1, 1-Cn+b:Cm-1+b)
194+
195+
k = -1
196+
if k in kr
197+
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, ValBC)
198+
for (i, Ji) in enumerate(b-k:n-1)
199+
Ckmb[i] += α * Bp[i] * Jm[Ji]
199200
end
200-
for i in 2:n
201-
C0[i] += α * Bp[i-1] * Jm[i-1]
201+
end
202+
203+
k = 0
204+
if k in kr
205+
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, Val(true))
206+
Ckmb[1] += α * Bp[1] * Jm[b-k]
207+
for (i, Ji) in enumerate(b-k+1:n-1)
208+
Ckmb[i] += α * B0[i] * J0[Ji]
209+
Ckmb[i+1] += α * Bp[i+1] * Jm[Ji]
210+
end
211+
Ckmb[n-(b-k)] += α * B0[n-(b-k)] * J0[n]
212+
end
213+
214+
for k = max(1, first(kr)):last(kr)
215+
Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, Val(true))
216+
Ckmb[1] += α * Bp[1] * Jm[b-k]
217+
for (i, Ji) in enumerate(b-k+1:n-1)
218+
Ckmb[i] += α * (Bm[i] * Jp[Ji] + B0[i] * J0[Ji])
219+
Ckmb[i+1] += α * Bp[i+1] * Jm[Ji]
202220
end
221+
Ckmb[n-(b-k)] += α * B0[n-(b-k)] * J0[n]
222+
end
223+
224+
C0 = _view(Val(true), C, band(0))
225+
Bm = _view(Val(true), B, band(-1))
226+
Bp = _view(Val(true), B, band(1))
227+
B0 = _view(Val(true), B, band(0))
228+
for i in 1:n-1
229+
C0[i] += α * (B0[i] * J0[i] + Bm[i] * Jp[i])
230+
C0[i+1] += α * Bp[i] * Jm[i]
231+
end
232+
C0[n] += α * B0[n] * J0[n]
233+
234+
return C
235+
end
236+
237+
# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is
238+
# specified by b, not by the parameters in B
239+
function jac_gbmm!(α, J, B, β, C, b, valJ, valBC)
240+
if β 1
241+
lmul!(β,C)
203242
end
204243

244+
n = size(J,1)
245+
Cn, Cm = size(C)
246+
247+
_jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, valJ, valBC)
248+
205249
C
206250
end
207251

@@ -220,54 +264,55 @@ function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T},
220264
ret = BandedMatrix(Zeros, S)
221265
shft=kr[1]-jr[1]
222266
ret[band(shft)] .= a[1]
223-
return ret::BandedMatrix{T}
267+
return ret
224268
elseif n==2
225269
# we have U_x = [1 α-x; 0 β]
226270
# for e_1^⊤ U_x\a == a[1]*I-(α-J)*a[2]/β == (a[1]-α*a[2]/β)*I + J*a[2]/β
227271
# implying
228272
α,β=recα(T,sp,1),recβ(T,sp,1)
229-
ret=Operator{T}(Recurrence(M.space))[kr,jr]::BandedMatrix{T}
273+
ret=Operator{T}(Recurrence(M.space))[kr,jr]
230274
lmul!(a[2]/β,ret)
231275
shft=kr[1]-jr[1]
232-
ret[band(shft)] .+= a[1]-α*a[2]/β
233-
return ret::BandedMatrix{T}
276+
@views ret[band(shft)] .+= a[1]-α*a[2]/β
277+
return ret
234278
end
235279

236280
jkr=max(1,min(jr[1],kr[1])-(n-1)÷2):max(jr[end],kr[end])+(n-1)÷2
237281

238282
#Multiplication is transpose
239283
J=Operator{T}(Recurrence(M.space))[jkr,jkr]
284+
valJ = all(>=(1), bandwidths(J)) ? Val(true) : Val(false)
240285

241286
B=n-1 # final bandwidth
242287

243288
# Clenshaw for operators
244-
Bk2 = BandedMatrix(Zeros{T}(size(J,1),size(J,2)), (B,B))
245-
Bk2[band(0)] .= a[n]/recβ(T,sp,n-1)
289+
Bk2 = BandedMatrix(Zeros{T}(size(J)), (B,B))
290+
dataview(view(Bk2, band(0))) .= a[n]/recβ(T,sp,n-1)
246291
α,β = recα(T,sp,n-1),recβ(T,sp,n-2)
247292
Bk1 = (-α/β)*Bk2
248-
view(Bk1, band(0)) .= (a[n-1]/β) .+ view(Bk1, band(0))
249-
jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0)
293+
dataview(view(Bk1, band(0))) .+= a[n-1]/β
294+
jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0,valJ, Val(true))
250295
b=1 # we keep track of bandwidths manually to reuse memory
251296
for k=n-2:-1:2
252297
α,β,γ=recα(T,sp,k),recβ(T,sp,k-1),recγ(T,sp,k+1)
253298
lmul!(-γ/β,Bk2)
254-
view(Bk2, band(0)) .= (a[k]/β) .+ view(Bk2, band(0))
255-
jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b)
299+
dataview(view(Bk2, band(0))) .+= a[k]/β
300+
jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b,valJ,Val(true))
256301
LinearAlgebra.axpy!(-α/β,Bk1,Bk2)
257302
Bk2,Bk1=Bk1,Bk2
258303
b+=1
259304
end
260305
α,γ=recα(T,sp,1),recγ(T,sp,2)
261306
lmul!(-γ,Bk2)
262-
view(Bk2, band(0)) .= a[1] .+ view(Bk2, band(0))
263-
jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b)
307+
dataview(view(Bk2, band(0))) .+= a[1]
308+
jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b,valJ,Val(false))
264309
LinearAlgebra.axpy!(-α,Bk1,Bk2)
265310

266311
# relationship between jkr and kr, jr
267312
kr2,jr2=kr.-jkr[1].+1,jr.-jkr[1].+1
268313

269314
# TODO: reuse memory of Bk2, though profile suggests it's not too important
270-
BandedMatrix(view(Bk2,kr2,jr2))::BandedMatrix{T}
315+
BandedMatrix(view(Bk2,kr2,jr2))
271316
end
272317

273318

0 commit comments

Comments
 (0)