Skip to content

Commit 5da41bb

Browse files
authored
Binary search in block for nD TrivialTensorizer (#355)
* binary search in block for nd TrivialTensorizer * Tuple in nD TrivialTensorizer iteration
1 parent 726187c commit 5da41bb

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/Multivariate/TensorSpace.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Base.keys(a::Tensorizer) = oneto(length(a))
3535

3636
function start(a::TrivialTensorizer{d}) where {d}
3737
# ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
38-
block = SVector{d}(Ones{Int}(d))
38+
block = ntuple(one, d)
3939
return (block, (0, nothing, nothing)), (0,length(a))
4040
end
4141

@@ -64,7 +64,7 @@ function next(a::TrivialTensorizer{d}, iterator_tuple) where {d}
6464

6565
# increase block, or initialize new block
6666
_res, iter_state = iterate(iterator, iter_state)
67-
res = SVector{d}(_res)
67+
res = Tuple(SVector{d}(_res))
6868
block = res.+1
6969
j = j+1
7070

@@ -79,8 +79,8 @@ end
7979

8080

8181
# (blockrow,blockcol), (subrow,subcol), (rowshift,colshift), (numblockrows,numblockcols), (itemssofar, length)
82-
start(a::Tensorizer2D) = _start(a::Tensorizer2D)
83-
start(a::TrivialTensorizer{2}) = _start(a::Tensorizer2D)
82+
start(a::Tensorizer2D) = _start(a)
83+
start(a::TrivialTensorizer{2}) = _start(a)
8484

8585
_start(a) = (1,1), (1,1), (0,0), (a.blocks[1][1],a.blocks[2][1]), (0,length(a))
8686

@@ -171,10 +171,22 @@ block(::TrivialTensorizer{2},n::Int) =
171171
Block(floor(Integer,sqrt(2n) + 1/2))
172172

173173
function block(::TrivialTensorizer{d},n::Int) where {d}
174-
order::Int = 0
174+
binomial(d, d) >= n && return Block(1)
175+
order = 1
175176
while binomial(order+d, d) < n
176-
order = order + 1
177+
order *= 2
177178
end
179+
searchords = order÷2:order
180+
# perform a binary search
181+
while length(searchords) > 1
182+
midpt = searchords[length(searchords)÷2]
183+
if binomial(midpt+d, d) < n
184+
searchords = (midpt + 1):last(searchords)
185+
else
186+
searchords = first(searchords):midpt
187+
end
188+
end
189+
order = searchords[]
178190
return Block(order+1)
179191
end
180192

0 commit comments

Comments
 (0)