Skip to content

Precompute bandwidth in sumspace deriv #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions src/Operators/general/InterlaceOperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,28 @@ end
const VectorInterlaceOperator = InterlaceOperator{T,1,DS,RS} where {T,DS,RS<:Space{D,R}} where {D,R<:AbstractVector}
const MatrixInterlaceOperator = InterlaceOperator{T,2,DS,RS} where {T,DS,RS<:Space{D,R}} where {D,R<:AbstractVector}

@static if VERSION >= v"1.8"
Base.@constprop :aggressive interlace_bandwidths(args...) = _interlace_bandwidths(args...)
else
interlace_bandwidths(args...) = _interlace_bandwidths(args...)
end

function InterlaceOperator(ops::AbstractMatrix{<:Operator},ds::Space,rs::Space)
@inline function _interlace_bandwidths(ops::AbstractMatrix{<:Operator},
ds, rs, allbanded = all(isbanded,ops))
# calculate bandwidths TODO: generalize
p=size(ops,1)
dsi = interlacer(ds)
rsi = interlacer(rs)

if size(ops,2) == p && all(isbanded,ops) &&# only support blocksize (1,) for now
if size(ops,2) == p && allbanded &&# only support blocksize (1,) for now
all(i->isa(i,AbstractFill) && getindex_value(i) == 1, dsi.blocks) &&
all(i->isa(i,AbstractFill) && getindex_value(i) == 1, rsi.blocks)

l,u = 0,0
for k=1:p,j=1:p
for k=axes(ops,1), j=axes(ops,2)
l=max(l,p*bandwidth(ops[k,j],1)+k-j)
end
for k=1:p,j=1:p
for k=axes(ops,1), j=axes(ops,2)
u=max(u,p*bandwidth(ops[k,j],2)+j-k)
end
elseif p == 1 && size(ops,2) == 2 && size(ops[1],2) == 1
Expand All @@ -105,20 +111,14 @@ function InterlaceOperator(ops::AbstractMatrix{<:Operator},ds::Space,rs::Space)
else
l,u = (1-dimension(rs),dimension(ds)-1) # not banded
end

MT = Matrix{Operator{promote_eltypeof(ops)}}
opsm = strictconvert(MT, ops)
InterlaceOperator(opsm,ds,rs,
cache(dsi),
cache(rsi),
(l,u))
l,u
end


function InterlaceOperator(ops::VectorOrTupleOfOp, ds::Space, rs::Space)
@inline function _interlace_bandwidths(ops::VectorOrTupleOfOp,
ds, rs, allbanded = all(isbanded,ops))
# calculate bandwidths
p=size(ops,1)
if all(isbanded,ops)
if allbanded
l,u = 0,0
#TODO: this code assumes an interlace strategy that might not be right
for k=1:p
Expand All @@ -130,13 +130,33 @@ function InterlaceOperator(ops::VectorOrTupleOfOp, ds::Space, rs::Space)
else
l,u = (1-dimension(rs),dimension(ds)-1) # not banded
end
l,u
end

function InterlaceOperator(ops::AbstractMatrix{<:Operator}, ds::Space, rs::Space,
bw = interlace_bandwidths(ops, ds, rs))

dsi = interlacer(ds)
rsi = interlacer(rs)

MT = Matrix{Operator{promote_eltypeof(ops)}}
opsm = strictconvert(MT, ops)
InterlaceOperator(opsm,ds,rs,
cache(dsi),
cache(rsi),
bw)
end


function InterlaceOperator(ops::VectorOrTupleOfOp, ds::Space, rs::Space,
bw = interlace_bandwidths(ops, ds, rs))

VT = Vector{Operator{promote_eltypeof(ops)}}
opsv = strictconvert(VT, convert_vector(ops))
InterlaceOperator(opsv,ds,rs,
cache(BlockInterlacer(tuple(blocklengths(ds)))),
cache(interlacer(rs)),
(l,u))
bw)
end

interlace_domainspace(ops::AbstractMatrix, ::Type{NoSpace}) = domainspace(ops)
Expand Down
19 changes: 13 additions & 6 deletions src/Spaces/ProductSpaceOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,24 @@ end
## Derivative

#TODO: do in @calculus_operator?
_spacename(::SumSpace) = SumSpace
_spacename(::PiecewiseSpace) = PiecewiseSpace

function InterlaceOperator_Diagonal(t, S)
allbanded = all(isbanded, t)
ds, rs = S, _spacename(S)(map(rangespace, t))
D = Diagonal(convert_vector_or_svector(t))
bw = interlace_bandwidths(D, ds, rs, allbanded)
InterlaceOperator(D, ds, rs, bw)
end

for (Op,OpWrap) in ((:Derivative,:DerivativeWrapper),(:Integral,:IntegralWrapper))
_Op = Symbol(:_, Op)
@eval begin
@inline function $_Op(S::PiecewiseSpace, k::Number)
assert_integer(k)
t = map(s->$Op(s,k),components(S))
D = Diagonal(convert_vector_or_svector(t))
O = InterlaceOperator(D, PiecewiseSpace)
O = InterlaceOperator_Diagonal(t, S)
$OpWrap(O,k)
end
@inline function $_Op(S::ArraySpace, k::Number)
Expand All @@ -254,8 +263,7 @@ end
# mixed bases.
if typeof(canonicaldomain(S))==typeof(domain(S))
t = map(s->Derivative(s,k),components(S))
D = Diagonal(convert_vector_or_svector(t))
O = InterlaceOperator(D, SumSpace)
O = InterlaceOperator_Diagonal(t, S)
DerivativeWrapper(O,k)
else
DefaultDerivative(S,k)
Expand Down Expand Up @@ -301,8 +309,7 @@ Multiplication(f::Fun{SumSpace{SV1,D,R1}},sp::SumSpace{SV2,D,R2}) where {SV1,SV2

function Multiplication(f::Fun, sp::SumSpace)
t = map(s->Multiplication(f,s),components(sp))
D = Diagonal(convert_vector_or_svector(t))
O = InterlaceOperator(D, SumSpace)
O = InterlaceOperator_Diagonal(t, sp)
MultiplicationWrapper(f, O)
end

Expand Down
2 changes: 1 addition & 1 deletion src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ function show(io::IO,ss::PiecewiseSpace)
end
show(io,s[1])
for sp in s[2:end]
print(io,"")
print(io,"")
show(io,sp)
end
if length(s) == 1
Expand Down