Skip to content

Commit e47454f

Browse files
committed
fix: name iter and disable single itersym
1 parent a695eba commit e47454f

File tree

2 files changed

+37
-27
lines changed

2 files changed

+37
-27
lines changed

src/constructors.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ end
7878

7979

8080
function LoopSet(q::Expr, mod::Symbol = :Main)
81-
contract_pass!(q)
8281
ls = LoopSet(mod)
82+
check_inputs!(q, ls.prepreamble)
83+
contract_pass!(q)
8384
copyto!(ls, q)
8485
resize!(ls.loop_order, num_loops(ls))
8586
ls
@@ -158,32 +159,32 @@ function process_args(
158159
end
159160
# check if the body of loop is a block, if not convert it to a block issue#395
160161
# and check if the range of loop is an enumerate, if it is replace it, issue#393
161-
function check_inputs!(q)
162+
function check_inputs!(q, prepreamble)
162163
if Meta.isexpr(q, :for)
163164
if !Meta.isexpr(q.args[2], :block)
164165
q.args[2] = Expr(:block, q.args[2])
165-
replace_enumerate!(q) # must after warp block
166+
replace_enumerate!(q, prepreamble) # must after warp block
166167
else # maybe inner loops in block
167-
replace_enumerate!(q)
168+
replace_enumerate!(q, prepreamble)
168169
for arg in q.args[2].args
169-
check_inputs!(arg) # check recursively for inner loop
170+
check_inputs!(arg, prepreamble) # check recursively for inner loop
170171
end
171172
end
172173
end
173174
return q
174175
end
175-
function replace_enumerate!(q)
176+
function replace_enumerate!(q, prepreamble)
176177
looprange = q.args[1]
177178
if Meta.isexpr(looprange, :block)
178179
for i in 1:length(looprange.args)
179-
convert_single_enumerate!(q, i)
180+
replace_single_enumerate!(q, prepreamble, i)
180181
end
181182
else
182-
convert_single_enumerate!(q)
183+
replace_single_enumerate!(q, prepreamble)
183184
end
184185
return q
185186
end
186-
function convert_single_enumerate!(q, i=nothing)
187+
function replace_single_enumerate!(q, prepreamble, i=nothing)
187188
if isnothing(i) # not nest loop
188189
looprange, body = q.args[1], q.args[2]
189190
else # nest loop
@@ -192,7 +193,13 @@ function convert_single_enumerate!(q, i=nothing)
192193
@assert Meta.isexpr(looprange, :(=), 2)
193194
itersyms, r = looprange.args
194195
if Meta.isexpr(r, :call, 2) && r.args[1] == :enumerate
195-
iter = r.args[2]
196+
_iter = r.args[2]
197+
if _iter isa Symbol
198+
iter = _iter
199+
else # name complex expr
200+
iter = gensym(:iter)
201+
push!(prepreamble.args, :($iter = $_iter))
202+
end
196203
if Meta.isexpr(itersyms, :tuple, 2)
197204
indsym, varsym = itersyms.args[1]::Symbol, itersyms.args[2]::Symbol
198205
_replace_looprange!(q, i, indsym, iter)
@@ -201,14 +208,8 @@ function convert_single_enumerate!(q, i=nothing)
201208
indsym = itersyms.args[1]::Symbol
202209
_replace_looprange!(q, i, indsym, iter)
203210
elseif itersyms isa Symbol # if itersyms are not unbox in loop range
204-
# generate new symbols to avoid name conflict
205-
indsym = gensym(Symbol(itersyms, :_ind))
206-
varsym = gensym(Symbol(itersyms, :_var))
207-
_replace_looprange!(q, i, indsym, iter)
208-
pushfirst!(body.args,
209-
:($varsym = $iter[$indsym + firstindex($iter) - 1]),
210-
:($itersyms = ($indsym, $varsym)), # regroud the indsym and varsym for user
211-
)
211+
throw(ArgumentError("`for $itersyms in enumerate($r)` is not supported,
212+
please use `for ($(itersyms)_i, $(itersyms)_v) in enumerate($r)` instead."))
212213
else
213214
throw(ArgumentError("Don't know how to handle expression `$itersyms`."))
214215
end
@@ -221,7 +222,6 @@ _replace_looprange!(q, i::Int, indsym, iter) = q.args[1].args[i] = :($indsym = B
221222
function turbo_macro(mod, src, q, args...)
222223
q = macroexpand(mod, q)
223224
if q.head === :for
224-
check_inputs!(q)
225225
ls = LoopSet(q, mod)
226226
inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
227227
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))

test/parsing_inputs.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using LoopVectorization, Test, ArrayInterface
2+
using LoopVectorization: check_inputs!
23

34
# macros for generate loops whose body is not a block
45
macro gen_loop_issue395(ex)
@@ -52,18 +53,27 @@ end
5253
@testset "enumerate, #393" begin
5354
A = zeros(4)
5455
B = zeros(4)
55-
C = zeros(4)
56-
@turbo for (i, x) in enumerate(A)
57-
A[i] = i + x
56+
C = zeros(4, 4)
57+
D = zeros(4, 4)
58+
@turbo for (i, x) in enumerate(1:4)
59+
A[i] = x
5860
end
5961
@turbo for (i,) in enumerate(B)
6062
B[i] += 1
6163
end
62-
@turbo for ix in enumerate(C)
63-
C[ix[1]] = ix[1] + ix[2]
64+
@turbo for (j, Aj) in enumerate(A), (i, Bi) in enumerate(B)
65+
C[i, j] = Aj * Bi
66+
end
67+
@turbo for (j, Bj) in enumerate(B)
68+
for (i, Ai) in enumerate(A)
69+
D[i, j] = Ai * Bj
70+
end
6471
end
65-
@test_throws ArgumentError @turbo for () in enumerate(A) end
6672
@test A == 1:4
67-
@test B == 1:4
68-
@test C == 1:4
73+
@test B == ones(4)
74+
@test A .* B' == C' == D
75+
@test_throws ArgumentError check_inputs!(:(for ix in enumerate(A)
76+
A[ix[1]] = ix[1] + ix[2]
77+
end), Any[])
78+
@test_throws ArgumentError check_inputs!(:(for () in enumerate(A); end), Any[])
6979
end

0 commit comments

Comments
 (0)