Skip to content

Commit 0264d00

Browse files
authored
Don't cache TrivialInterlacer (#504)
* Don't cache TrivialInterlacer * Fix show, add tests * Add comment to UncachedIterator
1 parent 3b60481 commit 0264d00

File tree

4 files changed

+70
-15
lines changed

4 files changed

+70
-15
lines changed

src/LinearAlgebra/helper.jl

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,28 @@ Base.isless(x::PosInfinity, y::Block{1}) = isless(x, Int(y))
543543
pad(A::BandedMatrix,n,::Colon) = pad(A,n,n+A.u) # Default is to get all columns
544544
columnrange(A,row::Integer) = max(1,row-bandwidth(A,1)):row+bandwidth(A,2)
545545

546+
abstract type AbstractCachedIterator{T,IT} end
547+
eltype(it::Type{<:AbstractCachedIterator{T}}) where {T} = T
548+
function Base.IteratorSize(::Type{<:AbstractCachedIterator{<:Any,IT}}) where {IT}
549+
Base.IteratorSize(IT) isa Base.IsInfinite ? Base.IsInfinite() : Base.HasLength()
550+
end
551+
552+
Base.keys(c::AbstractCachedIterator) = oneto(length(c))
553+
length(A::AbstractCachedIterator) = length(A.iterator)
554+
555+
# Lazy wrapper that mimics a CachedIterator, and has an `iterator` field
556+
struct UncachedIterator{T,IT} <: AbstractCachedIterator{T,IT}
557+
iterator :: IT
558+
end
559+
UncachedIterator(it::IT) where IT = UncachedIterator{eltype(it),IT}(it)
560+
561+
iterate(it::UncachedIterator, st...) = iterate(it.iterator, st...)
562+
getindex(it::UncachedIterator, k) = it.iterator[k]
546563

564+
Base.show(io::IO, C::UncachedIterator) = print(io, "$UncachedIterator(", C.iterator, ")")
547565

548566
## Store iterator
549-
mutable struct CachedIterator{T,IT}
567+
mutable struct CachedIterator{T,IT} <: AbstractCachedIterator{T,IT}
550568
iterator::IT
551569
storage::Vector{T}
552570
state
@@ -582,15 +600,6 @@ function resize!(it::CachedIterator{T},n::Integer) where {T}
582600
it
583601
end
584602

585-
586-
eltype(it::Type{<:CachedIterator{T}}) where {T} = T
587-
588-
function Base.IteratorSize(::Type{<:CachedIterator{<:Any,IT}}) where {IT}
589-
Base.IteratorSize(IT) isa Base.IsInfinite ? Base.IsInfinite() : Base.HasLength()
590-
end
591-
592-
Base.keys(c::CachedIterator) = oneto(length(c))
593-
594603
iterate(it::CachedIterator) = iterate(it,1)
595604
function iterate(it::CachedIterator,st::Int)
596605
if st > it.length && iterate(it.iterator,it.state...) === nothing
@@ -612,7 +621,6 @@ end
612621
@deprecate findfirst(A::CachedIterator, x) findfirst(x, A::CachedIterator)
613622
findfirst(x::T, A::CachedIterator{T}) where {T} = findfirst(==(x), A)
614623

615-
length(A::CachedIterator) = length(A.iterator)
616624

617625
## nocat
618626
vnocat(A...) = Base.vect(A...)
@@ -676,7 +684,7 @@ conv(x::AbstractFill, y::AbstractFill) = DSP.conv(x, y)
676684
## BlockInterlacer
677685
# interlaces coefficients by blocks
678686
# this has the property that all the coefficients of a block of a subspace
679-
# are grouped together, starting with the first bloc
687+
# are grouped together, starting with the first block
680688
#
681689
# TODO: cache sums
682690

@@ -747,3 +755,33 @@ iterate(it::TrivialInterlacer{N,OneToInf{Int}}, st...) where {N} =
747755
iterate(Iterators.product(1:N, axes(it.blocks[1],1)), st...)
748756

749757
cache(Q::BlockInterlacer) = CachedIterator(Q)
758+
759+
# don't cache a trivial interlacer, as indexing into it is fast
760+
cache(Q::TrivialInterlacer) = UncachedIterator(Q)
761+
function Base.getindex(Q::TrivialInterlacer{N}, i::Int) where {N}
762+
reverse(divrem(i-1,N) .+ 1)
763+
end
764+
function Base.getindex(Q::TrivialInterlacer, v::AbstractVector)
765+
TrivialInterlacerSection(Q, v)
766+
end
767+
768+
struct TrivialInterlacerSection{TI,I} <: AbstractVector{Tuple{Int,Int}}
769+
interlacer::TI
770+
inds::I
771+
end
772+
Base.size(t::TrivialInterlacerSection) = size(t.inds)
773+
Base.getindex(t::TrivialInterlacerSection, i::Int) = t.interlacer[t.inds[i]]
774+
function Base.getindex(t::TrivialInterlacerSection, i::AbstractVector{Int})
775+
TrivialInterlacerSection(t.interlacer, t.inds[i])
776+
end
777+
function findsub(t::TrivialInterlacerSection{<:TrivialInterlacer{d}}, n::Int) where {d}
778+
if 1 <= n <= d
779+
ind1 = findfirst(x->x[1]==n, t) # d terms need to be searched at most
780+
if ind1 === nothing
781+
return oneunit(firstindex(t)):d:zero(length(t))
782+
end
783+
return ind1:d:length(t)
784+
else
785+
return oneunit(firstindex(t)):d:zero(length(t))
786+
end
787+
end

src/Operators/general/InterlaceOperator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ for TYP in (:BandedMatrix, :BlockBandedMatrix, :BandedBlockBandedMatrix, :Ragged
398398
cr=L.rangeinterlacer[kr]
399399
cd=L.domaininterlacer[jr]
400400
for ν=1:size(L.ops,1),μ=1:size(L.ops,2)
401-
# indicies of ret
401+
# indices of ret
402402
ret_kr=findsub(cr,ν)
403403
ret_jr=findsub(cd,μ)
404404

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,26 @@ end
6666
for (a,b) in it
6767
@test a == b
6868
end
69+
C1 = cache(B1)
70+
C2 = cache(B2)
71+
for i in 0:nD
72+
indsB1 = ApproxFunBase.findsub(C1[1:10], i)
73+
indsB2 = ApproxFunBase.findsub(C2[1:10], i)
74+
@test indsB1 == indsB2
75+
indsB1 = ApproxFunBase.findsub(C1[2:2], i)
76+
indsB2 = ApproxFunBase.findsub(C2[2:2], i)
77+
@test indsB1 == indsB2
78+
end
6979
end
7080
test_nD(1)
7181
test_nD(2)
7282
test_nD(3)
83+
84+
B = ApproxFunBase.BlockInterlacer((o, o))
85+
C = cache(B)
86+
@test contains(Base.sprint(show, C), "UncachedIterator($(repr(B)))")
87+
@test first(C, 10) == C[1:10] == B[1:10] == first(B, 10)
88+
@test C[2:10][1:2:end] == B[2:10][1:2:end]
7389
end
7490
end
7591

test/show.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@
8787
end
8888
end
8989
@testset "Iterators" begin
90-
B = ApproxFunBase.BlockInterlacer((Ones{Int}(2), Ones{Int}(2)))
91-
@test repr(B) == "ApproxFunBase.BlockInterlacer((Ones(2), Ones(2)))"
90+
t = ([1,1], [1,1])
91+
B = ApproxFunBase.BlockInterlacer(t)
92+
@test repr(B) == "$(ApproxFunBase.BlockInterlacer)($(repr(t)))"
9293
C = cache(B)
9394
@test contains(repr(C), "Cached " * repr(B))
9495
end

0 commit comments

Comments
 (0)