Skip to content

Type-stability in 2d tensorizer iteration #374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ApproxFunBase"
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
version = "0.7.68"
version = "0.7.69"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
10 changes: 5 additions & 5 deletions src/LinearAlgebra/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ end
CachedIterator{T,IT}(it::IT, state) where {T,IT} = CachedIterator{T,IT}(it,T[],state,0)
CachedIterator(it::IT) where IT = CachedIterator{eltype(it),IT}(it, ())

function resize!(it::CachedIterator,n::Integer)
function resize!(it::CachedIterator{T},n::Integer) where {T}
m = it.length
if n > m
if n > length(it.storage)
Expand All @@ -559,10 +559,10 @@ function resize!(it::CachedIterator,n::Integer)
it.length = k-1
return it
end
it.storage[k] = xst[1]
it.state = (xst[2],)
v::T, st = xst
it.storage[k] = v
it.state = (st,)
end

it.length = n
end
it
Expand All @@ -579,7 +579,7 @@ Base.keys(c::CachedIterator) = oneto(length(c))

iterate(it::CachedIterator) = iterate(it,1)
function iterate(it::CachedIterator,st::Int)
if st == it.length + 1 && iterate(it.iterator,it.state...) === nothing
if st > it.length && iterate(it.iterator,it.state...) === nothing
nothing
else
(it[st],st+1)
Expand Down
28 changes: 16 additions & 12 deletions src/Multivariate/TensorSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ factor(d::AbstractProductSpace,k) = factors(d)[k]
# tensor entry of a tensor product of d spaces
# findfirst is overriden to get efficient inverse
# blocklengths is a tuple of block lengths, e.g., Chebyshev()^2
# would be Tensorizer((1:∞,1:∞))
# would be Tensorizer((Ones{Int}(∞), Ones{Int}(∞)))
# ConstantSpace() ⊗ Chebyshev()
# would be Tensorizer((1:1,1:∞))
# would be Tensorizer((1:1,Ones{Int}(∞)))
# and Chebyshev() ⊗ ArraySpace([Chebyshev(),Chebyshev()])
# would be Tensorizer((1:∞,2:2:∞))
# would be Tensorizer((Ones{Int}(∞), Fill(2,∞)))


struct Tensorizer{DMS<:Tuple}
Expand Down Expand Up @@ -79,7 +79,7 @@ function next(a::TrivialTensorizer{d}, iterator_tuple) where {d}
end


function done(a::TrivialTensorizer, iterator_tuple)
function done(a::TrivialTensorizer, iterator_tuple)::Bool
i, tot = last(iterator_tuple)
return i ≥ tot
end
Expand All @@ -89,12 +89,13 @@ end
start(a::Tensorizer2D) = _start(a)
start(a::TrivialTensorizer{2}) = _start(a)

_start(a) = (1,1), (1,1), (0,0), (a.blocks[1][1],a.blocks[2][1]), (0,length(a))
_start(a) = (1,1, 1,1, 0,0, a.blocks[1][1],a.blocks[2][1]), (0,length(a))

next(a::Tensorizer2D, state) = _next(a, state)
next(a::TrivialTensorizer{2}, state) = _next(a, state)
next(a::Tensorizer2D, state) = _next(a, state::typeof(_start(a)))
next(a::TrivialTensorizer{2}, state) = _next(a, state::typeof(_start(a)))

function _next(a, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot)))
function _next(a, st)
(K,J, k,j, rsh,csh, n,m), (i,tot) = st
ret = k+rsh,j+csh
if k==n && j==m # end of block
if J == 1 || K == length(a.blocks[1]) # end of new block
Expand All @@ -115,13 +116,16 @@ function _next(a, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot)))
else
k += 1
end
ret, ((K,J), (k,j), (rsh,csh), (n,m), (i+1,tot))
ret, ((K,J, k,j, rsh,csh, n,m), (i+1,tot))
end

done(a::Tensorizer2D, state) = _done(a, state)
done(a::TrivialTensorizer{2}, state) = _done(a, state)
done(a::Tensorizer2D, state) = _done(a, state::typeof(_start(a)))
done(a::TrivialTensorizer{2}, state) = _done(a, state::typeof(_start(a)))

_done(a, (_, _, _, _, (i,tot))) = i ≥ tot
function _done(a, st)::Bool
i, tot = last(st)
i ≥ tot
end

iterate(a::Tensorizer) = next(a, start(a))
function iterate(a::Tensorizer, st)
Expand Down