Skip to content

Commit a8c8831

Browse files
authored
IteratorSize for Tensorizer and BlockInterlacer (#365)
* IteratorSize for Tensorizer and BlockInterlacer * keys and IteratorSize for CachedIterator
1 parent 6e231f8 commit a8c8831

File tree

4 files changed

+31
-7
lines changed

4 files changed

+31
-7
lines changed

src/ApproxFunBase.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ function assert_integer(k::Number)
135135
return nothing
136136
end
137137

138+
function _IteratorSize(::Type{T}) where {T<:Tuple}
139+
s = ntuple(i-> Base.IteratorSize(fieldtype(T, i)), fieldcount(T))
140+
any(x -> x isa Base.IsInfinite, s) ? Base.IsInfinite() : Base.HasLength()
141+
end
142+
138143
include("LinearAlgebra/LinearAlgebra.jl")
139144
include("Fun.jl")
140145
include("onehotvector.jl")

src/LinearAlgebra/helper.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,10 @@ end
571571

572572
eltype(it::Type{<:CachedIterator{T}}) where {T} = T
573573

574+
Base.IteratorSize(::Type{<:CachedIterator{<:Any,IT}}) where {IT} = Base.IteratorSize(IT)
575+
576+
Base.keys(c::CachedIterator) = keys(c.iterator)
577+
574578
iterate(it::CachedIterator) = iterate(it,1)
575579
function iterate(it::CachedIterator,st::Int)
576580
if st == it.length + 1 && iterate(it.iterator,it.state...) === nothing
@@ -702,8 +706,9 @@ eltype(::Type{<:BlockInterlacer}) = Tuple{Int,Int}
702706

703707
dimensions(b::BlockInterlacer) = map(sum,b.blocks)
704708
dimension(b::BlockInterlacer,k) = sum(b.blocks[k])
705-
Base.length(b::BlockInterlacer) = mapreduce(sum,+,b.blocks)
709+
length(b::BlockInterlacer) = mapreduce(sum,+,b.blocks)
706710

711+
Base.IteratorSize(::Type{BlockInterlacer{T}}) where {T} = _IteratorSize(T)
707712

708713
# the state is always (whichblock,curblock,cursubblock,curcoefficients)
709714
# start(it::BlockInterlacer) = (1,1,map(start,it.blocks),ntuple(zero,length(it.blocks)))

src/Multivariate/TensorSpace.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@ struct Tensorizer{DMS<:Tuple}
2828
blocks::DMS
2929
end
3030

31+
const InfOnes = Ones{Int,1,Tuple{OneToInf{Int}}}
3132
const Tensorizer2D{AA, BB} = Tensorizer{Tuple{AA, BB}}
32-
const TrivialTensorizer{d} = Tensorizer{NTuple{d,Ones{Int,1,Tuple{OneToInf{Int}}}}}
33+
const TrivialTensorizer{d} = Tensorizer{NTuple{d,InfOnes}}
3334

3435
eltype(::Type{<:Tensorizer{<:Tuple{Vararg{Any,N}}}}) where {N} = NTuple{N,Int}
3536
dimensions(a::Tensorizer) = map(sum,a.blocks)
36-
Base.length(a::Tensorizer) = reduce(*, dimensions(a)) # easier type-inference than mapreduce
37+
length(a::Tensorizer) = reduce(*, dimensions(a)) # easier type-inference than mapreduce
38+
39+
Base.IteratorSize(::Type{Tensorizer{T}}) where {T<:Tuple} = _IteratorSize(T)
3740

3841
Base.keys(a::Tensorizer) = oneto(length(a))
3942

@@ -129,7 +132,9 @@ end
129132

130133
cache(a::Tensorizer) = CachedIterator(a)
131134

132-
function Base.findfirst(::TrivialTensorizer{2},kj::Tuple{Int,Int})
135+
@deprecate findfirst(t::Tensorizer, kj::NTuple{2,Int}) findfirst(kj, t)
136+
137+
function Base.findfirst(kj::NTuple{2,Int}, ::TrivialTensorizer{2})
133138
k,j=kj
134139
if k > 0 && j > 0
135140
n=k+j-2
@@ -138,7 +143,7 @@ function Base.findfirst(::TrivialTensorizer{2},kj::Tuple{Int,Int})
138143
nothing
139144
end
140145
end
141-
function Base.findfirst(sp::Tensorizer{<:NTuple{2,Ones{Int}}}, kj::NTuple{2,Int})
146+
function Base.findfirst(kj::NTuple{2,Int}, sp::Tensorizer{<:NTuple{2,Ones{Int}}})
142147
k,j=kj
143148

144149
len1, len2 = length(sp.blocks[1]), length(sp.blocks[2])
@@ -212,9 +217,9 @@ blocklengths(::TrivialTensorizer{2}) = 1:∞
212217
blocklengths(it::Tensorizer) = tensorblocklengths(it.blocks...)
213218
blocklengths(it::CachedIterator) = blocklengths(it.iterator)
214219

215-
function getindex(it::TrivialTensorizer{2},n::Integer)
220+
function getindex(it::TrivialTensorizer{2}, n::Integer)
216221
m=Int(block(it,n))
217-
p=findfirst(it,(1,m))
222+
p=findfirst((1,m), it)
218223
j=1+n-p
219224
j,m-j+1
220225
end

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,15 @@ end
511511
@test findfirst(t, vi) == i
512512
end
513513
end
514+
@testset "cache" begin
515+
ax = Ones{Int}(4);
516+
t = ApproxFunBase.Tensorizer((ax,ax))
517+
c = ApproxFunBase.cache(t)
518+
v = collect(c)
519+
@testset for i in eachindex(c)
520+
@test c[i] == v[i]
521+
end
522+
end
514523
end
515524

516525
@time include("ETDRK4Test.jl")

0 commit comments

Comments
 (0)