Skip to content

Commit 93e79dd

Browse files
authored
Type-stability in 2d tensorizer iteration (#374)
* type-stability in 2d tensorizer iteration * version bump to v0.7.69
1 parent 603096d commit 93e79dd

File tree

3 files changed

+22
-18
lines changed

3 files changed

+22
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunBase"
22
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
3-
version = "0.7.68"
3+
version = "0.7.69"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/LinearAlgebra/helper.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ end
546546
CachedIterator{T,IT}(it::IT, state) where {T,IT} = CachedIterator{T,IT}(it,T[],state,0)
547547
CachedIterator(it::IT) where IT = CachedIterator{eltype(it),IT}(it, ())
548548

549-
function resize!(it::CachedIterator,n::Integer)
549+
function resize!(it::CachedIterator{T},n::Integer) where {T}
550550
m = it.length
551551
if n > m
552552
if n > length(it.storage)
@@ -559,10 +559,10 @@ function resize!(it::CachedIterator,n::Integer)
559559
it.length = k-1
560560
return it
561561
end
562-
it.storage[k] = xst[1]
563-
it.state = (xst[2],)
562+
v::T, st = xst
563+
it.storage[k] = v
564+
it.state = (st,)
564565
end
565-
566566
it.length = n
567567
end
568568
it
@@ -579,7 +579,7 @@ Base.keys(c::CachedIterator) = oneto(length(c))
579579

580580
iterate(it::CachedIterator) = iterate(it,1)
581581
function iterate(it::CachedIterator,st::Int)
582-
if st == it.length + 1 && iterate(it.iterator,it.state...) === nothing
582+
if st > it.length && iterate(it.iterator,it.state...) === nothing
583583
nothing
584584
else
585585
(it[st],st+1)

src/Multivariate/TensorSpace.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ factor(d::AbstractProductSpace,k) = factors(d)[k]
1717
# tensor entry of a tensor product of d spaces
1818
# findfirst is overriden to get efficient inverse
1919
# blocklengths is a tuple of block lengths, e.g., Chebyshev()^2
20-
# would be Tensorizer((1:∞,1:∞))
20+
# would be Tensorizer((Ones{Int}(∞), Ones{Int}(∞)))
2121
# ConstantSpace() ⊗ Chebyshev()
22-
# would be Tensorizer((1:1,1:∞))
22+
# would be Tensorizer((1:1,Ones{Int}(∞)))
2323
# and Chebyshev() ⊗ ArraySpace([Chebyshev(),Chebyshev()])
24-
# would be Tensorizer((1:∞,2:2:∞))
24+
# would be Tensorizer((Ones{Int}(∞), Fill(2,∞)))
2525

2626

2727
struct Tensorizer{DMS<:Tuple}
@@ -79,7 +79,7 @@ function next(a::TrivialTensorizer{d}, iterator_tuple) where {d}
7979
end
8080

8181

82-
function done(a::TrivialTensorizer, iterator_tuple)
82+
function done(a::TrivialTensorizer, iterator_tuple)::Bool
8383
i, tot = last(iterator_tuple)
8484
return i tot
8585
end
@@ -89,12 +89,13 @@ end
8989
start(a::Tensorizer2D) = _start(a)
9090
start(a::TrivialTensorizer{2}) = _start(a)
9191

92-
_start(a) = (1,1), (1,1), (0,0), (a.blocks[1][1],a.blocks[2][1]), (0,length(a))
92+
_start(a) = (1,1, 1,1, 0,0, a.blocks[1][1],a.blocks[2][1]), (0,length(a))
9393

94-
next(a::Tensorizer2D, state) = _next(a, state)
95-
next(a::TrivialTensorizer{2}, state) = _next(a, state)
94+
next(a::Tensorizer2D, state) = _next(a, state::typeof(_start(a)))
95+
next(a::TrivialTensorizer{2}, state) = _next(a, state::typeof(_start(a)))
9696

97-
function _next(a, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot)))
97+
function _next(a, st)
98+
(K,J, k,j, rsh,csh, n,m), (i,tot) = st
9899
ret = k+rsh,j+csh
99100
if k==n && j==m # end of block
100101
if J == 1 || K == length(a.blocks[1]) # end of new block
@@ -115,13 +116,16 @@ function _next(a, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot)))
115116
else
116117
k += 1
117118
end
118-
ret, ((K,J), (k,j), (rsh,csh), (n,m), (i+1,tot))
119+
ret, ((K,J, k,j, rsh,csh, n,m), (i+1,tot))
119120
end
120121

121-
done(a::Tensorizer2D, state) = _done(a, state)
122-
done(a::TrivialTensorizer{2}, state) = _done(a, state)
122+
done(a::Tensorizer2D, state) = _done(a, state::typeof(_start(a)))
123+
done(a::TrivialTensorizer{2}, state) = _done(a, state::typeof(_start(a)))
123124

124-
_done(a, (_, _, _, _, (i,tot))) = i tot
125+
function _done(a, st)::Bool
126+
i, tot = last(st)
127+
i tot
128+
end
125129

126130
iterate(a::Tensorizer) = next(a, start(a))
127131
function iterate(a::Tensorizer, st)

0 commit comments

Comments
 (0)