Skip to content

Commit cb25a5f

Browse files
authored
Out-of-place pad for BandedMatrix in resizedata (#430)
1 parent 6621387 commit cb25a5f

File tree

8 files changed

+85
-93
lines changed

8 files changed

+85
-93
lines changed

src/Caching/almostbanded.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ function resizedata!(co::CachedOperator{T,<:AlmostBandedMatrix{T},<:InterlaceOpe
181181
end
182182

183183
(l,u)=bandwidths(co.data.bands)
184-
pad!(co.data,n,n+u)
184+
co.data = pad(co.data, n, n+u)
185185

186186
r = rank(co.data.fill)
187187
ind = findfirst(op->isinf(size(op,1)),co.op.ops)
@@ -227,8 +227,7 @@ function resizedata!(co::CachedOperator{T,<:AlmostBandedMatrix{T},<:InterlaceOpe
227227
p = length(d∞)
228228

229229
(l,u)=bandwidths(co.data.bands)
230-
pad!(co.data,n,n+u)
231-
co.data
230+
co.data = pad(co.data,n,n+u)
232231
# r is number of extra rows, ncols is number of extra columns
233232
r = rank(co.data.fill)
234233
ncols = mapreduce(d->isfinite(d) ? d : 0,+,ddims)
@@ -286,36 +285,38 @@ function resizedata!(QR::QROperator{<:CachedOperator{T,<:AlmostBandedMatrix{T}}}
286285
MO = QR.R_cache
287286
W = QR.H
288287

289-
R = MO.data.bands
290-
M = R.l+1 # number of diag+subdiagonal bands
288+
Rl, Ru = bandwidths(MO.data.bands)
289+
M = Rl + 1 # number of diag+subdiagonal bands
291290

292291
if col+M-1 MO.datasize[1]
293292
resizedata!(MO,(col+M-1)+100,:) # double the last rows
294293
end
295294

295+
R = MO.data.bands # has to be accessed after the resizedata!
296+
296297
if col > size(W,2)
297298
W = QR.H = unsafe_resize!(W,:,2col)
298299
end
299300

300301
F = MO.data.fill.U
301302

302303
for k = QR.ncols+1:col
303-
W[:,k] = view(R.data, (R.u+1).+(0:R.l), k) # diagonal and below
304+
W[:,k] = view(R.data, (Ru+1).+(0:Rl), k) # diagonal and below
304305
wp = view(W,:,k)
305306
W[1,k]+= flipsign(norm(wp),W[1,k])
306307
normalize!(wp)
307308

308309
# scale banded entries
309-
for j = k:k+R.u
310-
dind = R.u+1+k-j
310+
for j = k:k+Ru
311+
dind = Ru+1+k-j
311312
v = view(R.data, range(dind, length=M), j)
312313
dt = dot(wp,v)
313314
axpy!(-2dt,wp,v)
314315
end
315316

316317
# scale banded/filled entries
317-
for j = (k+R.u).+(1:M-1)
318-
p = j-k-R.u
318+
for j = (k+Ru).+(1:M-1)
319+
p = j-k-Ru
319320
v = view(R.data,1:M-p,j) # shift down each time
320321
wp2=view(wp,p+1:M)
321322
dt = dot(wp2,v)

src/Caching/banded.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function resizedata!(B::CachedOperator{T,<:BandedMatrix{T}},n::Integer,m_in::Int
1717
n = min(n, N)
1818

1919
if n > B.datasize[1]
20-
pad!(B.data,min(N,2n),m)
20+
B.data = pad(B.data,min(N,2n),m)
2121

2222
kr=B.datasize[1]+1:n
2323
jr=max(B.datasize[1]+1-B.data.l,1):min(n+B.data.u,M)
@@ -52,34 +52,36 @@ function resizedata!(QR::QROperator{<:CachedOperator{T,<:BandedMatrix{T}}}, ::Co
5252
MO=QR.R_cache
5353
W=QR.H
5454

55-
R=MO.data
56-
M=R.l+1 # number of diag+subdiagonal bands
55+
Rl,Ru = bandwidths(MO.data)
56+
M = Rl + 1 # number of diag+subdiagonal bands
5757

5858
if col+M-1 MO.datasize[1]
5959
resizedata!(MO,(col+M-1)+100,:) # double the last rows
6060
end
6161

62+
R = MO.data # has to be accessed after resizedata!, as the matrix might change
63+
6264
if col > size(W,2)
6365
W=QR.H=unsafe_resize!(W,:,2col)
6466
end
6567

6668
for k=QR.ncols+1:col
67-
W[:,k] = view(R.data, (R.u+1).+(0:R.l), k) # diagonal and below
69+
W[:,k] = view(R.data, (Ru+1).+(0:Rl), k) # diagonal and below
6870
wp=view(W,:,k)
6971
W[1,k]+= flipsign(norm(wp),W[1,k])
7072
normalize!(wp)
7173

7274
# scale banded entries
73-
for j=k:k+R.u
74-
dind=R.u+1+k-j
75+
for j=k:k+Ru
76+
dind=Ru+1+k-j
7577
v=view(R.data, range(dind, length=M), j)
7678
dt=dot(wp,v)
7779
axpy!(-2dt,wp,v)
7880
end
7981

8082
# scale banded/filled entries
81-
for j = (k+R.u).+(1:M-1)
82-
p=j-k-R.u
83+
for j = (k+Ru).+(1:M-1)
84+
p=j-k-Ru
8385
v=view(R.data,1:M-p,j) # shift down each time
8486
wp2=view(wp,p+1:M)
8587
dt=dot(wp2,v)

src/Fun.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ function Fun(sp::Space,v::AbstractVector{Any})
8686
end
8787
end
8888

89+
Fun(f::Fun) = f # Fun of Fun should be like a conversion
8990

9091
hasnumargs(f::Fun,k) = k == 1 || domaindimension(f) == k # all funs take a single argument as a SVector
9192

src/LinearAlgebra/AlmostBandedMatrix.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ function setindex!(B::AlmostBandedMatrix,v,k::Integer,j::Integer)
4949
end
5050

5151

52-
function pad!(B::AlmostBandedMatrix,n::Integer,m::Integer)
53-
pad!(B.bands,n,m)
54-
pad!(B.fill,n,m)
55-
B
52+
function pad(B::AlmostBandedMatrix,n::Integer,m::Integer)
53+
bands = pad(B.bands,n,m)
54+
fill = pad(B.fill,n,m)
55+
AlmostBandedMatrix(bands, fill)
5656
end

src/LinearAlgebra/LowRankMatrix.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ end
3434

3535
# constructors
3636

37-
function pad!(L::LowRankMatrix,n::Integer,::Colon)
38-
L.U=pad(L.U,n,:)
39-
L
37+
function pad(L::LowRankMatrix,n::Integer,::Colon)
38+
U = pad(L.U, n, :)
39+
LowRankMatrix(U, L.V)
4040
end
41-
function pad!(L::LowRankMatrix,::Colon,m::Integer)
42-
L.V=pad(L.V,m,:)
43-
L
41+
function pad(L::LowRankMatrix,::Colon,m::Integer)
42+
V = pad(L.V,m,:)
43+
LowRankMatrix(L.U, V)
4444
end
45-
pad!(L::LowRankMatrix,n::Integer,m::Integer) = pad!(pad!(L,n,:),:,m)
45+
pad(L::LowRankMatrix,n::Integer,m::Integer) = pad(pad(L,n,:),:,m)

src/LinearAlgebra/helper.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,13 @@ function pad(A::AbstractMatrix,n::Integer,m::Integer)
260260
end
261261
end
262262

263+
function pad(A::BandedMatrix, n::Integer, m::Integer)
264+
B = BandedMatrix{eltype(A)}(undef, (n,m), bandwidths(A))
265+
copyto!(B.data, A.data)
266+
B.data[length(A.data)+1:end] .= 0
267+
return B
268+
end
269+
263270
pad(A::AbstractMatrix,::Colon,m::Integer) = pad(A,size(A,1),m)
264271
pad(A::AbstractMatrix,n::Integer,::Colon) = pad(A,n,size(A,2))
265272

@@ -534,7 +541,7 @@ Base.isless(x::PosInfinity, y::Block{1}) = isless(x, Int(y))
534541

535542

536543

537-
pad!(A::BandedMatrix,n,::Colon) = pad!(A,n,n+A.u) # Default is to get all columns
544+
pad(A::BandedMatrix,n,::Colon) = pad(A,n,n+A.u) # Default is to get all columns
538545
columnrange(A,row::Integer) = max(1,row-bandwidth(A,1)):row+bandwidth(A,2)
539546

540547

src/constructors.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,47 @@ Fun(f,n::Integer) = Fun(f,ChebyshevInterval(),n)
242242
Fun(T::Type,d::AbstractVector) = Fun(T(),d)
243243

244244
Fun(f::Fun{SequenceSpace},s::Space) = Fun(s,f.coefficients)
245+
246+
"""
247+
Fun(f)
248+
249+
Return `Fun(f, space)` by choosing an appropriate `space` for the function.
250+
For univariate functions, `space` is chosen to be `Chebyshev()`, whereas for
251+
multivariate functions, it is a tensor product of `Chebyshev()` spaces.
252+
253+
# Examples
254+
```jldoctest
255+
julia> f = Fun(x -> x^2)
256+
Fun(Chebyshev(), [0.5, 0.0, 0.5])
257+
258+
julia> f(0.1) == (0.1)^2
259+
true
260+
261+
julia> f = Fun((x,y) -> x + y);
262+
263+
julia> f(0.1, 0.2) ≈ 0.3
264+
true
265+
```
266+
"""
267+
function Fun(f::Function)
268+
if hasonearg(f)
269+
# check for tuple
270+
try
271+
f(0)
272+
catch ex
273+
if ex isa BoundsError
274+
# assume its a tuple
275+
return Fun(f,ChebyshevInterval()^2)
276+
else
277+
rethrow()
278+
end
279+
end
280+
281+
Fun(f,ChebyshevInterval())
282+
elseif hasnumargs(f,2)
283+
Fun(f,ChebyshevInterval()^2)
284+
else
285+
error("Function not defined on interval or square")
286+
end
287+
end
288+

src/hacks.jl

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,3 @@
1-
## Functions that depend on the structure of BandedMatrix
2-
3-
4-
function pad!(A::BandedMatrix,n,m)
5-
A.data = pad(A.data,size(A.data,1),m)
6-
A.raxis = Base.OneTo(n)
7-
A
8-
end
9-
10-
11-
12-
# linear algebra
13-
14-
15-
## Constructors that involve MultivariateFun
16-
Fun(f::Fun) = f # Fun of Fun should be like a conversion
17-
18-
"""
19-
Fun(f)
20-
21-
Return `Fun(f, space)` by choosing an appropriate `space` for the function.
22-
For univariate functions, `space` is chosen to be `Chebyshev()`, whereas for
23-
multivariate functions, it is a tensor product of `Chebyshev()` spaces.
24-
25-
# Examples
26-
```jldoctest
27-
julia> f = Fun(x -> x^2)
28-
Fun(Chebyshev(), [0.5, 0.0, 0.5])
29-
30-
julia> f(0.1) == (0.1)^2
31-
true
32-
33-
julia> f = Fun((x,y) -> x + y);
34-
35-
julia> f(0.1, 0.2) ≈ 0.3
36-
true
37-
```
38-
"""
39-
function Fun(f::Function)
40-
if hasonearg(f)
41-
# check for tuple
42-
try
43-
f(0)
44-
catch ex
45-
if isa(ex,BoundsError)
46-
# assume its a tuple
47-
return Fun(f,ChebyshevInterval()^2)
48-
else
49-
throw(ex)
50-
end
51-
end
52-
53-
Fun(f,ChebyshevInterval())
54-
elseif hasnumargs(f,2)
55-
Fun(f,ChebyshevInterval()^2)
56-
else
57-
error("Function not defined on interval or square")
58-
end
59-
end
60-
61-
62-
63-
641
## These hacks support PDEs with block matrices
652

663

0 commit comments

Comments
 (0)