Skip to content

Commit c7a4941

Browse files
committed
traits in size/bandwidths in wrapperstructure
1 parent e720dba commit c7a4941

File tree

4 files changed

+106
-58
lines changed

4 files changed

+106
-58
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
2020
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
2121
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2222
LowRankApprox = "898213cb-b102-5a47-900c-97e73b919f73"
23+
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
2324
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2425
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2526
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -45,6 +46,7 @@ Infinities = "0.1"
4546
IntervalSets = "0.5, 0.6, 0.7"
4647
LazyArrays = "0.20, 0.21, 0.22, 1"
4748
LowRankApprox = "0.2, 0.3, 0.4, 0.5"
49+
SimpleTraits = "0.9"
4850
SpecialFunctions = "0.10, 1.0, 2"
4951
StaticArrays = "0.12, 1.0"
5052
julia = "1.6"

src/ApproxFunBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ import DomainSets: dimension
9595

9696
import IntervalSets: (..), endpoints
9797

98+
using SimpleTraits
99+
98100
const Vec{d,T} = SVector{d,T}
99101

100102
export pad!, pad, chop!, sample,

src/Operators/Operator.jl

Lines changed: 99 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,12 @@ end
109109
blocksize(A::Operator,k) = k==1 ? length(blocklengths(rangespace(A))) : length(blocklengths(domainspace(A)))
110110
blocksize(A::Operator) = (blocksize(A,1),blocksize(A,2))
111111

112-
112+
# operators need to define size(A, k::Integer)
113113
size(A::Operator) = (size(A,1),size(A,2))
114-
size(A::Operator,k::Integer) = k==1 ? dimension(rangespace(A)) : dimension(domainspace(A))
115114
length(A::Operator) = size(A,1) * size(A,2)
116115

116+
size(A::Operator,k::Integer) = opsize(A, k)
117+
117118

118119
# used to compute "end" for last index
119120
function lastindex(A::Operator, n::Integer)
@@ -139,50 +140,25 @@ Base.ndims(::Operator) = 2
139140
## bandrange and indexrange
140141
isbandedbelow(A::Operator) = isfinite(bandwidth(A,1))::Bool
141142
isbandedabove(A::Operator) = isfinite(bandwidth(A,2))::Bool
142-
isbanded(A::Operator) = all(isfinite, bandwidths(A))::Bool
143+
isbanded(A::Operator) = opisbanded(A)
143144

144145

145146
isbandedblockbandedbelow(_) = false
146147
isbandedblockbandedabove(_) = false
147148

148-
isbandedblockbanded(A::Operator) = isbandedblockbandedabove(A) && isbandedblockbandedbelow(A)
149+
isbandedblockbanded(A::Operator) = opisbandedblockbanded(A::Operator)
149150

150151

151-
# this should be determinable at compile time
152-
#TODO: I think it can be generalized to the case when the domainspace
153-
# blocklengths == rangespace blocklengths, in which case replace the definition
154-
# of p with maximum(blocklength(domainspace(A)))
155-
function blockbandwidths(A::Operator)
156-
hastrivialblocks(A) && return bandwidths(A)
157-
158-
if hasconstblocks(A)
159-
a,b = bandwidths(A)
160-
p = getindex_value(blocklengths(domainspace(A)))
161-
return (-fld(-a,p),-fld(-b,p))
162-
end
163-
164-
#TODO: Generalize to finite dimensional
165-
if size(A,2) == 1
166-
rs = rangespace(A)
167-
168-
if hasconstblocks(rs)
169-
a = bandwidth(A,1)
170-
p = getindex_value(blocklengths(rs))
171-
return (-fld(-a,p),0)
172-
end
173-
end
174-
175-
return (length(blocklengths(rangespace(A)))-1,length(blocklengths(domainspace(A)))-1)
176-
end
152+
blockbandwidths(A::Operator) = opblockbandwidths(A::Operator)
177153

178154
# assume dense blocks
179-
subblockbandwidths(K::Operator) = maximum(blocklengths(rangespace(K)))-1, maximum(blocklengths(domainspace(K)))-1
155+
subblockbandwidths(K::Operator) = opsubblockbandwidths(K::Operator)
180156

181157
isblockbandedbelow(A::Operator) = isfinite(blockbandwidth(A,1))::Bool
182158
isblockbandedabove(A::Operator) = isfinite(blockbandwidth(A,2))::Bool
183-
isblockbanded(A::Operator) = all(isfinite, blockbandwidths(A))::Bool
159+
isblockbanded(A::Operator) = opisblockbanded(A::Operator)
184160

185-
israggedbelow(A::Operator) = isbandedbelow(A)::Bool || isbandedblockbanded(A)::Bool || isblockbandedbelow(A)::Bool
161+
israggedbelow(A::Operator) = opisraggedbelow(A)
186162

187163

188164
blockbandwidth(K::Operator, k::Integer) = blockbandwidths(K)[k]
@@ -197,9 +173,7 @@ bandwidth(A::Operator, k::Integer) = bandwidths(A)[k]
197173
Return the bandwidth of `op` in the form `(l,u)`, where `l ≥ 0` represents
198174
the number of subdiagonals and `u ≥ 0` represents the number of superdiagonals.
199175
"""
200-
bandwidths(A::Operator) = (size(A,1)-1,size(A,2)-1)
201-
bandwidths(A::Operator, k::Integer) = bandwidths(A)[k]
202-
176+
bandwidths(A::Operator) = opbandwidths(A)
203177

204178

205179
## Strides
@@ -208,8 +182,7 @@ bandwidths(A::Operator, k::Integer) = bandwidths(A)[k]
208182
# A diagonal operator has essentially infinite stride
209183
# which we represent by a factorial, so that
210184
# the gcd with any number < 10 is the number
211-
stride(A::Operator) =
212-
isdiag(A) ? factorial(10) : 1
185+
stride(A::Operator) = opstride(A::Operator)
213186

214187
isdiag(A::Operator) = bandwidths(A)==(0,0)
215188
istriu(A::Operator) = bandwidth(A, 1) <= 0
@@ -486,36 +459,105 @@ end
486459
# Convenience for wrapper ops
487460
unwrap_axpy!(α,P,A) = axpy!(α,view(parent(P).op,P.indexes[1],P.indexes[2]),A)
488461
iswrapper(_) = false
489-
haswrapperstructure(_) = false
462+
463+
haswrapperstructure(@nospecialize(::Type)) = false
464+
haswrapperstructure(x::Operator) = haswrapperstructure(typeof(x))
465+
466+
@traitdef HasWrapperStructure{X}
467+
@traitimpl HasWrapperStructure{X} <- haswrapperstructure(X)
468+
469+
# Forward various structure query functions to the parent for wrappers
470+
471+
@traitfn opbandwidths(A::X) where {X; !HasWrapperStructure{X}} =
472+
(size(A,1)-1,size(A,2)-1)
473+
@traitfn opbandwidths(A::X) where {X; HasWrapperStructure{X}} =
474+
bandwidths(A.op)
475+
476+
@traitfn opstride(A::X) where {X; HasWrapperStructure{X}} =
477+
stride(A.op)
478+
@traitfn opstride(A::X) where {X; !HasWrapperStructure{X}} =
479+
isdiag(A) ? factorial(10) : 1
480+
481+
@traitfn opisblockbanded(A::X) where {X; HasWrapperStructure{X}} =
482+
isblockbanded(A.op)
483+
@traitfn opisblockbanded(A::X) where {X; !HasWrapperStructure{X}} =
484+
all(isfinite, blockbandwidths(A))::Bool
485+
486+
@traitfn opisbandedblockbanded(A::X) where {X; HasWrapperStructure{X}} =
487+
isbandedblockbanded(A.op)
488+
@traitfn opisbandedblockbanded(A::X) where {X; !HasWrapperStructure{X}} =
489+
isbandedblockbandedabove(A) && isbandedblockbandedbelow(A)
490+
491+
@traitfn opisbanded(A::X) where {X; HasWrapperStructure{X}} =
492+
isbanded(A.op)
493+
@traitfn opisbanded(A::X) where {X; !HasWrapperStructure{X}} =
494+
all(isfinite, bandwidths(A))::Bool
495+
496+
@traitfn opisraggedbelow(A::X) where {X; HasWrapperStructure{X}} =
497+
israggedbelow(A.op)
498+
@traitfn function opisraggedbelow(A::X) where {X; !HasWrapperStructure{X}}
499+
isbandedbelow(A)::Bool ||
500+
isbandedblockbanded(A)::Bool ||
501+
isblockbandedbelow(A)::Bool
502+
end
503+
504+
# this should be determinable at compile time
505+
#TODO: I think it can be generalized to the case when the domainspace
506+
# blocklengths == rangespace blocklengths, in which case replace the definition
507+
# of p with maximum(blocklength(domainspace(A)))
508+
@traitfn opblockbandwidths(A::X) where {X; HasWrapperStructure{X}} =
509+
opblockbandwidths(A.op)
510+
@traitfn function opblockbandwidths(A::X) where {X; !HasWrapperStructure{X}}
511+
hastrivialblocks(A) && return bandwidths(A)
512+
513+
if hasconstblocks(A)
514+
a,b = bandwidths(A)
515+
p = getindex_value(blocklengths(domainspace(A)))
516+
return (-fld(-a,p),-fld(-b,p))
517+
end
518+
519+
#TODO: Generalize to finite dimensional
520+
if size(A,2) == 1
521+
rs = rangespace(A)
522+
523+
if hasconstblocks(rs)
524+
a = bandwidth(A,1)
525+
p = getindex_value(blocklengths(rs))
526+
return (-fld(-a,p),0)
527+
end
528+
end
529+
530+
return (length(blocklengths(rangespace(A)))-1,length(blocklengths(domainspace(A)))-1)
531+
end
532+
533+
@traitfn opsubblockbandwidths(A::X) where {X; HasWrapperStructure{X}} =
534+
subblockbandwidths(A.op)
535+
@traitfn opsubblockbandwidths(A::X) where {X; !HasWrapperStructure{X}} =
536+
maximum(blocklengths(rangespace(A)))-1, maximum(blocklengths(domainspace(A)))-1
537+
538+
@traitfn opsize(A::X, k::Integer) where {X; HasWrapperStructure{X}} =
539+
opsize(A.op, k)
540+
@traitfn opsize(::X, k::PosInfinity) where {X; HasWrapperStructure{X}} = ℵ₀
541+
542+
defaultsize(A, k) = k==1 ? dimension(rangespace(A)) : dimension(domainspace(A))
543+
@traitfn opsize(A::X, k::Integer) where {X; !HasWrapperStructure{X}} =
544+
defaultsize(A, k)
490545

491546
# use this for wrapper operators that have the same structure but
492547
# not necessarily the same entries
493548
#
494549
# Ex: c*op or real(op)
495-
macro wrapperstructure(Wrap, forwardsize = true)
496-
fns = [:(ApproxFunBase.bandwidths),:(LinearAlgebra.stride),
497-
:(ApproxFunBase.isbandedblockbanded),:(ApproxFunBase.isblockbanded),
498-
:(ApproxFunBase.israggedbelow),:(ApproxFunBase.isbanded),
499-
:(ApproxFunBase.blockbandwidths),:(ApproxFunBase.subblockbandwidths),
500-
:(LinearAlgebra.issymmetric)]
501-
502-
if forwardsize
503-
fns = [fns; :(Base.size)]
504-
end
550+
macro wrapperstructure(Wrap)
551+
fns = [:(LinearAlgebra.issymmetric)]
505552

506553
v1 = map(fns) do func
507-
508554
:($func(D::$Wrap) = $func(D.op))
509555
end
510556

511557
fns2 = [:(ApproxFunBase.bandwidth),:(ApproxFunBase.colstart),:(ApproxFunBase.colstop),
512558
:(ApproxFunBase.rowstart),:(ApproxFunBase.rowstop),:(ApproxFunBase.blockbandwidth),
513559
:(ApproxFunBase.subblockbandwidth)]
514560

515-
if forwardsize
516-
fns2 = [fns2; :(Base.size)]
517-
end
518-
519561
v2 = map(fns2) do func
520562
quote
521563
$func(D::$Wrap,k::Integer) = $func(D.op,k)
@@ -524,7 +566,7 @@ macro wrapperstructure(Wrap, forwardsize = true)
524566
end
525567

526568
ret = quote
527-
ApproxFunBase.haswrapperstructure(::$Wrap) = true
569+
ApproxFunBase.haswrapperstructure(::Type{<:$Wrap}) = true
528570
$(v1...)
529571
$(v2...)
530572
end
@@ -537,7 +579,7 @@ end
537579
# use this for wrapper operators that have the same entries but
538580
# not necessarily the same spaces
539581
#
540-
macro wrappergetindex(Wrap, forwardsize = true)
582+
macro wrappergetindex(Wrap)
541583
v = map((:(ApproxFunBase.BandedMatrix),:(ApproxFunBase.RaggedMatrix),
542584
:(Base.Matrix),:(Base.Vector),:(Base.AbstractVector))) do TYP
543585
quote
@@ -602,7 +644,7 @@ macro wrappergetindex(Wrap, forwardsize = true)
602644
ApproxFunBase.default_BandedBlockBandedMatrix)
603645
end
604646

605-
ApproxFunBase.@wrapperstructure($Wrap, $forwardsize) # structure is automatically inherited
647+
ApproxFunBase.@wrapperstructure($Wrap) # structure is automatically inherited
606648
end
607649

608650
esc(ret)

src/Operators/spacepromotion.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ end
4646

4747
# Similar to wrapper, but different domain/domainspace/rangespace
4848

49-
@wrappergetindex SpaceOperator false
49+
@wrappergetindex SpaceOperator
50+
51+
size(S::SpaceOperator, k::Integer) = defaultsize(S, k)
5052

5153
# SpaceOperator can change blocks, so we need to override this
5254
getindex(A::SpaceOperator,KR::BlockRange, JR::BlockRange) = defaultgetindex(A,KR,JR)

0 commit comments

Comments
 (0)