Skip to content

Commit a695eba

Browse files
committed
feat: replace enumerate
1 parent d077140 commit a695eba

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

src/constructors.jl

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,23 +157,71 @@ function process_args(
157157
inline, check_empty, u₁, u₂, v, threads, warncheckarg
158158
end
159159
# check if the body of loop is a block, if not convert it to a block issue#395
160-
function check_loopbody!(q)
161-
if q isa Expr && q.head == :for
160+
# and check if the range of loop is an enumerate, if it is replace it, issue#393
161+
function check_inputs!(q)
162+
if Meta.isexpr(q, :for)
162163
if !Meta.isexpr(q.args[2], :block)
163164
q.args[2] = Expr(:block, q.args[2])
164-
else
165+
replace_enumerate!(q) # must after warp block
166+
else # maybe inner loops in block
167+
replace_enumerate!(q)
165168
for arg in q.args[2].args
166-
check_loopbody!(arg) # check recursively for inner loop
169+
check_inputs!(arg) # check recursively for inner loop
167170
end
168171
end
169172
end
170173
return q
171174
end
175+
function replace_enumerate!(q)
176+
looprange = q.args[1]
177+
if Meta.isexpr(looprange, :block)
178+
for i in 1:length(looprange.args)
179+
convert_single_enumerate!(q, i)
180+
end
181+
else
182+
convert_single_enumerate!(q)
183+
end
184+
return q
185+
end
186+
function convert_single_enumerate!(q, i=nothing)
187+
if isnothing(i) # not nest loop
188+
looprange, body = q.args[1], q.args[2]
189+
else # nest loop
190+
looprange, body = q.args[1].args[i], q.args[2]
191+
end
192+
@assert Meta.isexpr(looprange, :(=), 2)
193+
itersyms, r = looprange.args
194+
if Meta.isexpr(r, :call, 2) && r.args[1] == :enumerate
195+
iter = r.args[2]
196+
if Meta.isexpr(itersyms, :tuple, 2)
197+
indsym, varsym = itersyms.args[1]::Symbol, itersyms.args[2]::Symbol
198+
_replace_looprange!(q, i, indsym, iter)
199+
pushfirst!(body.args, :($varsym = $iter[$indsym + firstindex($iter) - 1]))
200+
elseif Meta.isexpr(itersyms, :tuple, 1) # like `for (i,) in enumerate(...)`
201+
indsym = itersyms.args[1]::Symbol
202+
_replace_looprange!(q, i, indsym, iter)
203+
elseif itersyms isa Symbol # if itersyms are not unbox in loop range
204+
# generate new symbols to avoid name conflict
205+
indsym = gensym(Symbol(itersyms, :_ind))
206+
varsym = gensym(Symbol(itersyms, :_var))
207+
_replace_looprange!(q, i, indsym, iter)
208+
pushfirst!(body.args,
209+
:($varsym = $iter[$indsym + firstindex($iter) - 1]),
210+
:($itersyms = ($indsym, $varsym)), # regroud the indsym and varsym for user
211+
)
212+
else
213+
throw(ArgumentError("Don't know how to handle expression `$itersyms`."))
214+
end
215+
end
216+
return q
217+
end
218+
_replace_looprange!(q, ::Nothing, indsym, iter) = q.args[1] = :($indsym = Base.OneTo(length($iter)))
219+
_replace_looprange!(q, i::Int, indsym, iter) = q.args[1].args[i] = :($indsym = Base.OneTo(length($iter)))
172220

173221
function turbo_macro(mod, src, q, args...)
174222
q = macroexpand(mod, q)
175223
if q.head === :for
176-
check_loopbody!(q)
224+
check_inputs!(q)
177225
ls = LoopSet(q, mod)
178226
inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
179227
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))

test/parsing_inputs.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,22 @@ end
4848
@test E == A * B'
4949
@test F == C * E
5050
end
51+
52+
@testset "enumerate, #393" begin
53+
A = zeros(4)
54+
B = zeros(4)
55+
C = zeros(4)
56+
@turbo for (i, x) in enumerate(A)
57+
A[i] = i + x
58+
end
59+
@turbo for (i,) in enumerate(B)
60+
B[i] += 1
61+
end
62+
@turbo for ix in enumerate(C)
63+
C[ix[1]] = ix[1] + ix[2]
64+
end
65+
@test_throws ArgumentError @turbo for () in enumerate(A) end
66+
@test A == 1:4
67+
@test B == 1:4
68+
@test C == 1:4
69+
end

0 commit comments

Comments
 (0)