@@ -128,80 +128,124 @@ end
128
128
129
129
# Scalar indexing of a concrete multiplication operator between polynomial spaces:
# evaluate the 1×1 sub-block `M[k:k, j:j]` and read its (1, 1) entry.
function getindex(M::ConcreteMultiplication{C,PS,T}, k::Integer, j::Integer) where {PS<:PolynomialSpace,T,C<:PolynomialSpace}
    block = M[k:k, j:j]
    return block[1, 1]
end
130
130
131
# `dataview(V)` maps a band view `V` onto a view of the underlying banded storage.
# If taking `band(0)` of a banded matrix yields a `BandedMatrices.BandedMatrixBand`,
# the implementation shipped with BandedMatrices is forwarded to; otherwise a local
# copy of that function is defined.
if view(brand(0, 0, 0, 0), band(0)) isa BandedMatrices.BandedMatrixBand
    dataview(V) = BandedMatrices.dataview(V)
else
    #=
    dataview is broken on BandedMatrices v0.17.6 and older.
    We copy the function over from BandedMatrices.jl, which is distributed under the MIT license
    See https://github.com/JuliaLinearAlgebra/BandedMatrices.jl/blob/master/LICENSE
    =#
    function dataview(V)
        # V wraps another wrapper around the matrix, hence the double `parent`
        mat = parent(parent(V))
        # index of the band that V selects
        bandidx = first(parentindices(V)).band.i
        nrows, ncols = size(mat)
        _, upper = bandwidths(mat)
        storage = BandedMatrices.bandeddata(mat)
        # row of `storage` holding this band, and the columns the band occupies
        row = upper - bandidx + 1
        cols = (max(bandidx, 0) + 1):min(ncols, nrows + bandidx)
        return view(storage, row, cols)
    end
end
131
148
149
# Band-view helper selected by a leading flag argument.
# Fallback for any flag / any array: an ordinary `view` of `A` at `b`.
function _view(::Any, A, b)
    return view(A, b)
end

# `Val(true)` together with a `BandedMatrix`: take the same view and hand it to
# `dataview`.
function _view(::Val{true}, A::BandedMatrix, b)
    bandview = view(A, b)
    return dataview(bandview)
end
132
151
152
# Collect the four band views used by one step of `_jac_gbmm!`:
# band `bmk*f` of `C`, and bands `bmk-1`, `bmk`, `bmk+1` of `B`, each with its sign
# flipped when `f` is negative. The first three views are always requested with
# `Val(true)`; `ValBC` is forwarded to `_view` for the `bmk+1` band of `B` only.
# Returns the tuple `(C band, B band bmk-1, B band bmk, B band bmk+1)`.
function _get_bands(B, C, bmk, f, ValBC)
    C_band = _view(Val(true), C, band(bmk * f))

    minus_off = flipsign(bmk - 1, f)
    B_minus = _view(Val(true), B, band(minus_off))

    zero_off = flipsign(bmk, f)
    B_zero = _view(Val(true), B, band(zero_off))

    plus_off = flipsign(bmk + 1, f)
    B_plus = _view(ValBC, B, band(plus_off))

    return (C_band, B_minus, B_zero, B_plus)
end
133
159
134
# Band-by-band kernel behind `jac_gbmm!`: accumulates the α-scaled products of the
# three bands (-1, 0, 1) of `J` with neighbouring bands of `B` into the bands of `C`.
#
# `β` is accepted but not used here; `jac_gbmm!` applies it to `C` before calling.
# `(Cn, Cm)` is the size of `C`, `n` is `size(J, 1)`, and `b` is the bandwidth of `B`
# tracked by the caller. `ValJ` is forwarded to `_view` for the bands of `J`; `ValBC`
# is forwarded to `_get_bands` for the outermost bands (the `k = -1` steps) only.
# Returns `C`.
function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC)
    Jsup = _view(ValJ, J, band(1))
    Jdiag = _view(ValJ, J, band(0))
    Jsub = _view(ValJ, J, band(-1))

    # ---- positive bands b-k of C, k = -1, 0, ..., b-1, restricted to bands C has ----
    pos_ks = intersect(-1:b-1, b-Cm+1:b-1+Cn)

    # The k = -1, k = 0 and k ≥ 1 cases are written out separately so that each
    # inner loop indexes the band data directly and does a single pass.
    if -1 in pos_ks
        # band b+1 of C: only the band-1 part of J contributes
        Cband, Bminus, _, _ = _get_bands(B, C, b + 1, 1, ValBC)
        for i in 1:n-b-1
            Cband[i] += α * Bminus[i+1] * Jsup[i]
        end
    end

    if 0 in pos_ks
        # band b of C: bands 1 and 0 of J contribute
        Cband, Bminus, Bzero, _ = _get_bands(B, C, b, 1, Val(true))
        for i in 1:n-b
            Cband[i] += α * (Bminus[i+1] * Jsup[i] + Bzero[i] * Jdiag[i])
        end
    end

    for k in max(1, first(pos_ks)):last(pos_ks)
        # bands b-k of C: all three bands of J contribute, except at the first entry
        Cband, Bminus, Bzero, Bplus = _get_bands(B, C, b - k, 1, Val(true))
        Cband[1] += α * (Bminus[2] * Jsup[1] + Bzero[1] * Jdiag[1])
        for i in 2:n-b+k
            Cband[i] += α * (Bminus[i+1] * Jsup[i] + Bzero[i] * Jdiag[i] + Bplus[i-1] * Jsub[i-1])
        end
    end

    # ---- negative bands k-b of C, k = -1, 0, ..., b-1, restricted to bands C has ----
    neg_ks = intersect(-1:b-1, 1-Cn+b:Cm-1+b)

    if -1 in neg_ks
        # band -(b+1) of C: only the band -1 part of J contributes
        Cband, Bplus, _, _ = _get_bands(B, C, b + 1, -1, ValBC)
        for Ji in b+1:n-1
            i = Ji - b
            Cband[i] += α * Bplus[i] * Jsub[Ji]
        end
    end

    if 0 in neg_ks
        # band -b of C: bands -1 and 0 of J contribute
        Cband, Bplus, Bzero, _ = _get_bands(B, C, b, -1, Val(true))
        Cband[1] += α * Bplus[1] * Jsub[b]
        for Ji in b+1:n-1
            i = Ji - b
            Cband[i] += α * Bzero[i] * Jdiag[Ji]
            Cband[i+1] += α * Bplus[i+1] * Jsub[Ji]
        end
        Cband[n-b] += α * Bzero[n-b] * Jdiag[n]
    end

    for k in max(1, first(neg_ks)):last(neg_ks)
        # bands -(b-k) of C: all three bands of J contribute
        d = b - k
        Cband, Bplus, Bzero, Bminus = _get_bands(B, C, d, -1, Val(true))
        Cband[1] += α * Bplus[1] * Jsub[d]
        for Ji in d+1:n-1
            i = Ji - d
            Cband[i] += α * (Bminus[i] * Jsup[Ji] + Bzero[i] * Jdiag[Ji])
            Cband[i+1] += α * Bplus[i+1] * Jsub[Ji]
        end
        Cband[n-d] += α * Bzero[n-d] * Jdiag[n]
    end

    # ---- main diagonal of C ----
    Cdiag = _view(Val(true), C, band(0))
    Bsub = _view(Val(true), B, band(-1))
    Bsup = _view(Val(true), B, band(1))
    Bdiag = _view(Val(true), B, band(0))
    for i in 1:n-1
        Cdiag[i] += α * (Bdiag[i] * Jdiag[i] + Bsub[i] * Jsup[i])
        Cdiag[i+1] += α * Bsup[i] * Jsub[i]
    end
    Cdiag[n] += α * Bdiag[n] * Jdiag[n]

    return C
end

# Fast implementation of C[:,:] = α*J*B+β*C where the bandwidth of B is
# specified by b, not by the parameters in B.
# `valJ` and `valBC` are passed through unchanged to `_jac_gbmm!`.
# `C` is modified in place and returned.
function jac_gbmm!(α, J, B, β, C, b, valJ, valBC)
    # scale C first; skipped when β is 1
    β ≠ 1 && lmul!(β, C)

    n = size(J, 1)
    rows, cols = size(C)
    _jac_gbmm!(α, J, B, β, C, b, (rows, cols), n, valJ, valBC)

    return C
end
207
251
@@ -220,54 +264,55 @@ function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T},
220
264
ret = BandedMatrix (Zeros, S)
221
265
shft= kr[1 ]- jr[1 ]
222
266
ret[band (shft)] .= a[1 ]
223
- return ret:: BandedMatrix{T}
267
+ return ret
224
268
elseif n== 2
225
269
# we have U_x = [1 α-x; 0 β]
226
270
# for e_1^⊤ U_x\a == a[1]*I-(α-J)*a[2]/β == (a[1]-α*a[2]/β)*I + J*a[2]/β
227
271
# implying
228
272
α,β= recα (T,sp,1 ),recβ (T,sp,1 )
229
- ret= Operator {T} (Recurrence (M. space))[kr,jr]:: BandedMatrix{T}
273
+ ret= Operator {T} (Recurrence (M. space))[kr,jr]
230
274
lmul! (a[2 ]/ β,ret)
231
275
shft= kr[1 ]- jr[1 ]
232
- ret[band (shft)] .+ = a[1 ]- α* a[2 ]/ β
233
- return ret:: BandedMatrix{T}
276
+ @views ret[band (shft)] .+ = a[1 ]- α* a[2 ]/ β
277
+ return ret
234
278
end
235
279
236
280
jkr= max (1 ,min (jr[1 ],kr[1 ])- (n- 1 )÷ 2 ): max (jr[end ],kr[end ])+ (n- 1 )÷ 2
237
281
238
282
# Multiplication is transpose
239
283
J= Operator {T} (Recurrence (M. space))[jkr,jkr]
284
+ valJ = all (>= (1 ), bandwidths (J)) ? Val (true ) : Val (false )
240
285
241
286
B= n- 1 # final bandwidth
242
287
243
288
# Clenshaw for operators
244
- Bk2 = BandedMatrix (Zeros {T} (size (J, 1 ), size (J, 2 )), (B,B))
245
- Bk2[ band (0 )] .= a[n]/ recβ (T,sp,n- 1 )
289
+ Bk2 = BandedMatrix (Zeros {T} (size (J)), (B,B))
290
+ dataview ( view ( Bk2, band (0 ))) .= a[n]/ recβ (T,sp,n- 1 )
246
291
α,β = recα (T,sp,n- 1 ),recβ (T,sp,n- 2 )
247
292
Bk1 = (- α/ β)* Bk2
248
- view (Bk1, band (0 )) . = ( a[n- 1 ]/ β) .+ view (Bk1, band ( 0 ))
249
- jac_gbmm! (one (T)/ β,J,Bk2,one (T),Bk1,0 )
293
+ dataview ( view (Bk1, band (0 ))) .+ = a[n- 1 ]/ β
294
+ jac_gbmm! (one (T)/ β,J,Bk2,one (T),Bk1,0 ,valJ, Val ( true ) )
250
295
b= 1 # we keep track of bandwidths manually to reuse memory
251
296
for k= n- 2 : - 1 : 2
252
297
α,β,γ= recα (T,sp,k),recβ (T,sp,k- 1 ),recγ (T,sp,k+ 1 )
253
298
lmul! (- γ/ β,Bk2)
254
- view (Bk2, band (0 )) . = ( a[k]/ β) .+ view (Bk2, band ( 0 ))
255
- jac_gbmm! (1 / β,J,Bk1,one (T),Bk2,b)
299
+ dataview ( view (Bk2, band (0 ))) .+ = a[k]/ β
300
+ jac_gbmm! (1 / β,J,Bk1,one (T),Bk2,b,valJ, Val ( true ) )
256
301
LinearAlgebra. axpy! (- α/ β,Bk1,Bk2)
257
302
Bk2,Bk1= Bk1,Bk2
258
303
b+= 1
259
304
end
260
305
α,γ= recα (T,sp,1 ),recγ (T,sp,2 )
261
306
lmul! (- γ,Bk2)
262
- view (Bk2, band (0 )) . = a[1 ] .+ view (Bk2, band ( 0 ))
263
- jac_gbmm! (one (T),J,Bk1,one (T),Bk2,b)
307
+ dataview ( view (Bk2, band (0 ))) .+ = a[1 ]
308
+ jac_gbmm! (one (T),J,Bk1,one (T),Bk2,b,valJ, Val ( false ) )
264
309
LinearAlgebra. axpy! (- α,Bk1,Bk2)
265
310
266
311
# relationship between jkr and kr, jr
267
312
kr2,jr2= kr.- jkr[1 ]. + 1 ,jr.- jkr[1 ]. + 1
268
313
269
314
# TODO : reuse memory of Bk2, though profile suggests it's not too important
270
- BandedMatrix (view (Bk2,kr2,jr2)):: BandedMatrix{T}
315
+ BandedMatrix (view (Bk2,kr2,jr2))
271
316
end
272
317
273
318
0 commit comments