Skip to content

Commit 4748b14

Browse files
authored
fix: #395, check if loop body is a block and wrap it if not (#396)
* fix: #395, check if loop body is a block and wrap it if not * refactor: test head with isexpr, use `@turbo` * Bump version to 0.12.107 * fix: missing `!`
1 parent 23efbf4 commit 4748b14

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.106"
4+
version = "0.12.107"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/constructors.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,24 @@ function process_args(
156156
end
157157
inline, check_empty, u₁, u₂, v, threads, warncheckarg
158158
end
159+
# Check whether the body of a loop is a block; if it is not, wrap it in one (issue #395).
160+
"""
    check_loopbody!(q)

Ensure that the body of a `for` expression is a `:block` expression, mutating `q`
in place, and return `q`.

If `q` is not a `:for` expression it is returned untouched. Otherwise a body that is
not a `:block` is wrapped in `Expr(:block, body)`, and every statement of the
(possibly freshly wrapped) body is checked recursively so that directly nested loops
are normalized as well.
"""
function check_loopbody!(q)
    Meta.isexpr(q, :for) || return q
    if !Meta.isexpr(q.args[2], :block)
        q.args[2] = Expr(:block, q.args[2])
    end
    # Recurse after wrapping as well: the statement that was just wrapped may itself
    # be a `for` loop whose body is not a block.
    foreach(check_loopbody!, q.args[2].args)
    return q
end
172+
159173
function turbo_macro(mod, src, q, args...)
160174
q = macroexpand(mod, q)
161175
if q.head === :for
176+
check_loopbody!(q)
162177
ls = LoopSet(q, mod)
163178
inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
164179
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))

test/grouptests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ const START_TIME = time()
88

99
@time if LOOPVECTORIZATION_TEST == "all" || LOOPVECTORIZATION_TEST == "part1"
1010
@time include("broadcast.jl")
11+
@time include("parsing_inputs.jl")
1112
end
1213

1314
@time if LOOPVECTORIZATION_TEST == "all" || LOOPVECTORIZATION_TEST == "part2"

test/parsing_inputs.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using LoopVectorization, Test, ArrayInterface
2+
3+
# Macros that generate loops whose body is not a block
4+
# Build `@turbo for idx = axes(arr, 1) ret[idx] = arr[idx] end` from an `arr[idx]`
# expression. The loop body is a bare assignment rather than a `:block`.
macro gen_loop_issue395(ex)
    arr, idx = ex.args
    iterspec = Expr(:(=), idx, :(axes($arr, 1)))
    assignment = Expr(:(=), Expr(:ref, :ret, idx), Expr(:ref, arr, idx))
    bare_loop = Expr(:for, iterspec, assignment)
    return esc(:(@turbo $bare_loop))
end
10+
# Build a single `@turbo` loop copying `A[i]` into `B[i]`; the loop body is a bare
# assignment rather than a `:block`.
macro gen_single_loop(B, A)
    iter = :(indices(($B, $A), 1))
    copy_stmt = Expr(:(=), :($B[i]), :($A[i]))
    bare_loop = Expr(:for, Expr(:(=), :i, iter), copy_stmt)
    return esc(:(@turbo $bare_loop))
end
15+
# Build a two-index `@turbo` loop computing `C[i, j] = A[i] * B[j]`; both iteration
# specs live in the loop head and the body is a bare assignment, not a `:block`.
macro gen_nest_loop(C, A, B)
    j_iter = :(j = indices(($C, $B), (2, 1)))
    i_iter = :(i = indices(($C, $A), 1))
    product_stmt = :($C[i, j] = $A[i] * $B[j])
    bare_loop = Expr(:for, Expr(:block, j_iter, i_iter), product_stmt)
    return esc(:(@turbo $bare_loop))
end
21+
# Build a `@turbo` matrix-multiply kernel (`C[j, i] = sum over k of A[j, k] * B[k, i]`)
# whose outer loop has a regular block body while the inner `k` loop has a bare
# (non-block) body.
macro gen_A_mul_B(C, A, B)
    accumulate = :(Cji += $A[j, k] * $B[k, i])
    k_iter = :(k = indices(($A, $B), (2, 1)))
    k_loop = Expr(:for, k_iter, accumulate)
    nest = :(
        for i in indices(($C, $B), 2), j in indices(($C, $A), 1)
            Cji = zero(eltype($C))
            $k_loop
            $C[j, i] = Cji
        end
    )
    return esc(:(@turbo $nest))
end
33+
34+
# Regression tests for issue #395: `@turbo` must accept loops whose body is not a
# `:block` expression (such loops can only be produced programmatically via macros).
@testset "check_block, #395" begin
    A = rand(4)
    B = rand(4)
    C = rand(4, 4)
    D = zeros(4)
    E = zeros(4, 4)
    F = zeros(4, 4)
    ret = zeros(4)
    @gen_single_loop D A
    @gen_loop_issue395 B[i]  # writes into the local `ret`
    @gen_nest_loop E A B
    @gen_A_mul_B F C E
    # Plain copies and single products are bit-for-bit reproducible.
    @test D == A
    @test ret == B
    @test E == A * B'
    # The matrix product is a floating-point reduction; summation order and FMA use
    # may differ from the reference implementation, so compare approximately.
    @test F ≈ C * E
end

0 commit comments

Comments
 (0)