Skip to content

Commit 887acf9

Browse files
authored
no BlasFloat specialzation for banded/approxbanded (#426)
* no BlasFloat specialzation for banded/approxbanded * resizedata in ragged * resizedata in blockbanded
1 parent 190d74d commit 887acf9

File tree

4 files changed

+45
-277
lines changed

4 files changed

+45
-277
lines changed

src/Caching/almostbanded.jl

Lines changed: 5 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ function resizedata!(QR::QROperator{<:CachedOperator{T,<:AlmostBandedMatrix{T}}}
310310
dind = R.u+1+k-j
311311
v = view(R.data,dind:dind+M-1,j)
312312
dt = dot(wp,v)
313-
axpy!(-2*dt,wp,v)
313+
axpy!(-2dt,wp,v)
314314
end
315315

316316
# scale banded/filled entries
@@ -320,87 +320,17 @@ function resizedata!(QR::QROperator{<:CachedOperator{T,<:AlmostBandedMatrix{T}}}
320320
wp2=view(wp,p+1:M)
321321
dt = dot(wp2,v)
322322
for=k:k+p-1
323-
@inbounds dt = muladd(conj(W[ℓ-k+1,k]),
324-
unsafe_getindex(MO.data.fill,ℓ,j),dt)
323+
dt = muladd(conj(W[ℓ-k+1,k]), MO.data.fill[ℓ,j], dt)
325324
end
326-
axpy!(-2*dt,wp2,v)
325+
axpy!(-2dt,wp2,v)
327326
end
328327

329328
# scale filled entries
330329

331-
for j = 1:size(F,2)
330+
for j = axes(F,2)
332331
v = view(F,k:k+M-1,j) # the k,jth entry of F
333332
dt = dot(wp,v)
334-
axpy!(-2*dt,wp,v)
335-
end
336-
end
337-
QR.ncols = col
338-
QR
339-
end
340-
341-
342-
343-
# BLAS versions, requires BlasFloat
344-
345-
function resizedata!(QR::QROperator{<:CachedOperator{T,<:AlmostBandedMatrix{T}}}, ::Colon, col) where {T<:BlasFloat}
346-
if col QR.ncols
347-
return QR
348-
end
349-
350-
MO = QR.R_cache
351-
W = QR.H
352-
353-
R = MO.data.bands
354-
M = R.l+1 # number of diag+subdiagonal bands
355-
356-
if col+M-1 MO.datasize[1]
357-
resizedata!(MO,(col+M-1)+100,:) # double the last rows
358-
end
359-
360-
if col > size(W,2)
361-
W = QR.H = unsafe_resize!(W,:,2col)
362-
end
363-
364-
F = MO.data.fill.U
365-
366-
f = pointer(F)
367-
m,n = size(R)
368-
w = pointer(W)
369-
r = pointer(R.data)
370-
sz = sizeof(T)
371-
st = stride(R.data,2)
372-
stw = stride(W,2)
373-
374-
for k = QR.ncols+1:col
375-
v = r+sz*(R.u + (k-1)*st) # diagonal entry
376-
wp = w+stw*sz*(k-1) # k-th column of W
377-
BLAS.blascopy!(M,v,1,wp,1)
378-
W[1,k]+= flipsign(BLAS.nrm2(M,wp,1),W[1,k])
379-
normalize!(M,wp)
380-
381-
for j = k:k+R.u
382-
v = r+sz*(R.u + (k-1)*st + (j-k)*(st-1))
383-
dt = dot(M,wp,1,v,1)
384-
BLAS.axpy!(M,-2*dt,wp,1,v,1)
385-
end
386-
387-
for j = k+R.u+1:k+R.u+M-1
388-
p = j-k-R.u
389-
v = r+sz*((j-1)*st) # shift down each time
390-
dt = dot(M-p,wp+p*sz,1,v,1)
391-
for=k:k+p-1
392-
@inbounds dt = muladd(conj(W[ℓ-k+1,k]),
393-
unsafe_getindex(MO.data.fill,ℓ,j),dt)
394-
end
395-
BLAS.axpy!(M-p,-2*dt,wp+p*sz,1,v,1)
396-
end
397-
398-
fp = f+(k-1)*sz
399-
fst = stride(F,2)
400-
for j = 1:size(F,2)
401-
v = fp+fst*(j-1)*sz # the k,jth entry of F
402-
dt = dot(M,wp,1,v,1)
403-
BLAS.axpy!(M,-2*dt,wp,1,v,1)
333+
axpy!(-2dt,wp,v)
404334
end
405335
end
406336
QR.ncols = col

src/Caching/banded.jl

Lines changed: 5 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ function QROperator(R::CachedOperator{T,<:BandedMatrix{T}}) where T
4242
end
4343

4444

45-
function resizedata!(QR::QROperator{<:CachedOperator{T,<:BandedMatrix{T},
46-
MM,DS,RS,BI}},
47-
::Colon,col) where {T,MM,DS,RS,BI}
45+
function resizedata!(QR::QROperator{<:CachedOperator{T,<:BandedMatrix{T}}}, ::Colon,col) where {T}
4846
if col QR.ncols
4947
return QR
5048
end
5149

50+
col = min(col, size(QR,2))
51+
5252
MO=QR.R_cache
5353
W=QR.H
5454

@@ -74,7 +74,7 @@ function resizedata!(QR::QROperator{<:CachedOperator{T,<:BandedMatrix{T},
7474
dind=R.u+1+k-j
7575
v=view(R.data,dind:dind+M-1,j)
7676
dt=dot(wp,v)
77-
axpy!(-2*dt,wp,v)
77+
axpy!(-2dt,wp,v)
7878
end
7979

8080
# scale banded/filled entries
@@ -83,66 +83,7 @@ function resizedata!(QR::QROperator{<:CachedOperator{T,<:BandedMatrix{T},
8383
v=view(R.data,1:M-p,j) # shift down each time
8484
wp2=view(wp,p+1:M)
8585
dt=dot(wp2,v)
86-
axpy!(-2*dt,wp2,v)
87-
end
88-
end
89-
QR.ncols=col
90-
QR
91-
end
92-
93-
94-
# BLAS versions, requires BlasFloat
95-
96-
97-
98-
function resizedata!(QR::QROperator{<:CachedOperator{T,<:BandedMatrix{T},
99-
MM,DS,RS,BI}},
100-
::Colon,col) where {T<:BlasFloat,MM,DS,RS,BI}
101-
if col QR.ncols
102-
return QR
103-
end
104-
105-
col = min(col, size(QR,2))
106-
107-
MO=QR.R_cache
108-
W=QR.H
109-
110-
R=MO.data
111-
M=R.l+1 # number of diag+subdiagonal bands
112-
113-
if col+M-1 MO.datasize[1]
114-
resizedata!(MO,(col+M-1)+100,:) # double the last rows
115-
end
116-
117-
if col > size(W,2)
118-
W=QR.H=unsafe_resize!(W,:,2col)
119-
end
120-
121-
m,n=size(R)
122-
w=pointer(W)
123-
r=pointer(R.data)
124-
sz=sizeof(T)
125-
st=stride(R.data,2)
126-
stw=stride(W,2)
127-
128-
for k=QR.ncols+1:col
129-
v=r+sz*(R.u + (k-1)*st) # diagonal entry
130-
wp=w+stw*sz*(k-1) # k-th column of W
131-
BLAS.blascopy!(M,v,1,wp,1)
132-
W[1,k]+= flipsign(BLAS.nrm2(M,wp,1),W[1,k])
133-
normalize!(M,wp)
134-
135-
for j=k:k+R.u
136-
v=r+sz*(R.u + (k-1)*st + (j-k)*(st-1))
137-
dt = dot(M,wp,1,v,1)
138-
BLAS.axpy!(M,-2*dt,wp,1,v,1)
139-
end
140-
141-
for j=k+R.u+1:k+R.u+M-1
142-
p=j-k-R.u
143-
v=r+sz*((j-1)*st) # shift down each time
144-
dt = dot(M-p,wp+p*sz,1,v,1)
145-
BLAS.axpy!(M-p,-2*dt,wp+p*sz,1,v,1)
86+
axpy!(-2dt,wp2,v)
14687
end
14788
end
14889
QR.ncols=col

src/Caching/blockbanded.jl

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,10 @@ QROperator(R::CachedOperator{T,BlockBandedMatrix{T}}) where {T} =
198198
# end
199199

200200
# always resize by column
201-
resizedata!(QR::QROperator{CachedOperator{T,BlockBandedMatrix{T},
202-
MM,DS,RS,BI}},
203-
::Colon, col::Int) where {T,MM,DS,RS,BI} =
201+
resizedata!(QR::QROperator{<:CachedOperator{T,BlockBandedMatrix{T}}}, ::Colon, col::Int) where {T} =
204202
resizedata!(QR, :, block(domainspace(QR.R_cache),col))
205203

206-
function resizedata!(QR::QROperator{CachedOperator{T,BlockBandedMatrix{T},
207-
MM,DS,RS,BI}},
208-
::Colon, COL::Block) where {T<:BlasFloat,MM,DS,RS,BI}
204+
function resizedata!(QR::QROperator{<:CachedOperator{T,BlockBandedMatrix{T}}}, ::Colon, COL::Block) where {T<:BlasFloat}
209205
MO = QR.R_cache
210206
W = QR.H
211207
R = MO.data
@@ -227,7 +223,6 @@ function resizedata!(QR::QROperator{CachedOperator{T,BlockBandedMatrix{T},
227223
K_end = Int(blockcolstop(MO, Block(J_col))) # last row block in last column
228224
J_end = Int(blockrowstop(MO, Block(K_end))) # QR will affect up to this column
229225
j_end = blockstop(domainspace(MO), Block(J_end)) # we need to resize up this column
230-
sz = sizeof(T)
231226

232227
if j_end MO.datasize[2]
233228
# add columns up to column rs, which is last column affected by QR
@@ -249,12 +244,11 @@ function resizedata!(QR::QROperator{CachedOperator{T,BlockBandedMatrix{T},
249244
resize!(W.data, W.cols[end]-1)
250245
end
251246

252-
w = pointer(W.data)
253-
r = pointer(R.data)
247+
ri = firstindex(R.data)
254248

255249
bs = R.block_sizes
256250

257-
for j =QR.ncols+1 : col # first column of block
251+
for j = QR.ncols+1 : col # first column of block
258252
bi = findblockindex.(bs.axes, (j, j)) # converts from global indices to block indices
259253
K1, J1 = Int.(block.(bi)) # this is the diagonal block corresponding to j
260254
κ, ξ = blockindex.(bi)
@@ -267,51 +261,37 @@ function resizedata!(QR::QROperator{CachedOperator{T,BlockBandedMatrix{T},
267261
k_end = last(bs.axes[1][Block(K_CS)])
268262

269263
w_j = W.cols[j] # the data index for the j-th column of W
270-
wp = w+sz*(w_j-1) # j-th column of W
271264

272265
M = k_end - j + 1 # the number of entries we are diagonalizing. we know the stride tells us the total number of rows
273266

274-
BLAS.blascopy!(M, r+sz*shft, 1, wp, 1) # copy the column into W
267+
WM = view(W.data, range(w_j, length=M))
268+
RM = view(R.data, range(ri + shft, length=M))
275269

270+
copyto!(WM, RM)
276271

277272
# we need to scale the first entry and then normalize
278-
W.data[w_j] += flipsign(BLAS.nrm2(M,wp,1), W.data[w_j])
279-
normalize!(M, wp)
273+
W.data[w_j] += flipsign(norm(WM), W.data[w_j])
274+
normalize!(WM)
280275

281276
# scale rest of columns in first block
282277
# for ξ_2 = 2:
283278

284279
for ξ_2 = ξ:length(bs.axes[2][Block(J1)])
285280
# we now apply I-2v*v' in place
286-
r_sh = r+sz*(shft + st*(ξ_2-ξ)) # the pointer the (j,ξ_2)-th entry
287-
dt = dot(M, wp, 1, r_sh, 1)
288-
BLAS.axpy!(M, -2*dt, wp, 1, r_sh ,1)
281+
rish = ri + shft + st*(ξ_2-ξ)
282+
RM = view(R.data, range(rish, length=M))
283+
dt = dot(WM, RM)
284+
axpy!(-2dt, WM, RM)
289285
end
290286

291287
for J = J1+1:min(K1+u,J_end)
292288
st = bs.block_strides[J]
293289
shft = bs.block_starts[K1,J] + κ-2 # the index of the pointer to the j, j entry
294290
for ξ_2 = axes(bs.axes[2][Block(J)],1)
295291
# we now apply I-2v*v' in place
296-
# r_sh = r+sz*(shft + st*(ξ_2-1)) # the pointer the (j,ξ_2)-th entry
297-
298-
# TODO: remove these debugging statement
299-
# @assert w_j-1 + M ≤ length(W.data)
300-
# @assert shft + st*(ξ_2-1) + M ≤ length(R.data)
301-
# @assert 0 ≤ w_j-1
302-
# if ! (0 ≤ shft + st*(ξ_2-1))
303-
# @show shft, st, ξ_2, l, u
304-
# @show κ, bs.block_starts[K1,J]
305-
# @show K1, J
306-
# @show MO.op
307-
# end
308-
# dt = dot(M, wp, 1, r_sh, 1)
309-
# BLAS.axpy!(M, -2*dt, wp, 1, r_sh ,1)
310-
311-
dt = dot(view(W.data, w_j:w_j+M-1) ,
312-
view(R.data, shft + st*(ξ_2-1) +1:shft + st*(ξ_2-1) +M))
313-
axpy!(-2*dt, view(W.data, w_j:w_j+M-1) ,
314-
view(R.data, shft + st*(ξ_2-1) +1:shft + st*(ξ_2-1) +M))
292+
RM = view(R.data, shft + st*(ξ_2-1) .+ (1:M))
293+
dt = dot(WM, RM)
294+
axpy!(-2dt, WM, RM)
315295
end
316296
end
317297
end

0 commit comments

Comments
 (0)