Skip to content

Commit a3c19e2

Browse files
committed
ifelse reduct fixes
1 parent 2bd9b58 commit a3c19e2

File tree

5 files changed

+64
-19
lines changed

5 files changed

+64
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.86"
4+
version = "0.12.87"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/codegen/lower_threads.jl

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -213,25 +213,38 @@ end
213213
end
214214

215215
function outer_reduct_combine_expressions(ls::LoopSet, retv)
216-
gf = GlobalRef(Core, :getfield)
217-
q = Expr(:block, :(var"#load#thread#ret#" = $gf(ThreadingUtilities.load(var"#thread#ptr#", typeof($retv), $(reg_size(ls))),2,false)))
218-
# push!(q.args, :(@show var"#load#thread#ret#"))
219-
for (i,or) enumerate(ls.outer_reductions)
220-
op = ls.operations[or]
221-
var = name(op)
222-
mvar = mangledvar(op)
223-
instr = instruction(op)
224-
out = Symbol(mvar, "##onevec##")
225-
instrcall = Expr(:call, lv(reduce_to_onevecunroll(instr)))
226-
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), out))
227-
if length(ls.outer_reductions) > 1
228-
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Expr(:call, gf, Symbol("#load#thread#ret#"), i, false)))
229-
else
230-
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Symbol("#load#thread#ret#")))
216+
gf = GlobalRef(Core, :getfield)
217+
q = Expr(:block, :(var"#load#thread#ret#" = $gf(ThreadingUtilities.load(var"#thread#ptr#", typeof($retv), $(reg_size(ls))),2,false)))
218+
# push!(q.args, :(@show var"#load#thread#ret#"))
219+
for (i,or) enumerate(ls.outer_reductions)
220+
op = ls.operations[or]
221+
var = name(op)
222+
mvar = mangledvar(op)
223+
instr = instruction(op)
224+
out = Symbol(mvar, "##onevec##")
225+
instrcall::Expr = if instr.instr :ifelse
226+
Expr(:call, lv(reduce_to_onevecunroll(instr)))
227+
else
228+
reductexpr::Expr = let ls = ls #FIXME: this should be tested...
229+
ifelse_reduction(:IfElseReduced, op) do opv
230+
@assert length(ls.outer_reductions) > 1
231+
j = findfirst(==(identifier(opv)), ls.outer_reductions)
232+
otherarg = Expr(:call, lv(:vecmemaybe), Expr(:call, GlobalRef(Core,:getfield), Symbol("#load#thread#ret#"), j, false))
233+
Expr(:call, lv(:vecmemaybe), Symbol(mangledvar(opv), "##onevec##")), (otherarg,)
231234
end
232-
push!(q.args, Expr(:(=), out, Expr(:call, :data, instrcall)))
235+
end
236+
Expr(:call, reductexpr)
233237
end
234-
q
238+
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), out))
239+
if length(ls.outer_reductions) > 1
240+
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Expr(:call, gf, Symbol("#load#thread#ret#"), i, false)))
241+
else
242+
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Symbol("#load#thread#ret#")))
243+
end
244+
push!(q.args, Expr(:(=), out, Expr(:call, :data, instrcall)))
245+
# push!(q.args, Expr(:(=), out, :(@show $data($instrcall))))
246+
end
247+
q
235248
end
236249

237250
function thread_loop_summary!(ls::LoopSet, ua::UnrollArgs, threadedloop::Loop, issecondthreadloop::Bool)

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ Execute an `@turbo` block. The block's code is represented via the arguments:
713713
@generated function _turbo_!(
714714
::Val{var"#UNROLL#"}, ::Val{var"#OPS#"}, ::Val{var"#ARF#"}, ::Val{var"#AM#"}, ::Val{var"#LPSYM#"}, ::Val{Tuple{var"#LB#",var"#V#"}}, var"#flattened#var#arguments#"::Vararg{Any,var"#num#vargs#"}
715715
) where {var"#UNROLL#", var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#", var"#V#", var"#num#vargs#"}
716-
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
716+
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
717717
ls = _turbo_loopset(var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#".parameters, var"#V#".parameters, var"#UNROLL#")
718718
pushfirst!(ls.preamble.args, :(var"#lv#tuple#args#" = reassemble_tuple(Tuple{var"#LB#",var"#V#"}, var"#flattened#var#arguments#")))
719719
post = hoist_constant_memory_accesses!(ls)

test/ifelsemasks.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,14 @@ T = Float32
468468
end
469469
end
470470

471+
function absmax_tturbo(a) # LV threaded
472+
result=zero(eltype(a))
473+
@tturbo for i in 1:length(a)
474+
abs(a[i])>result && (result=abs(a[i]))
475+
end
476+
result
477+
end
478+
471479
function findminturbo(x)
472480
indmin = 0
473481
minval = typemax(eltype(x))
@@ -478,6 +486,16 @@ T = Float32
478486
end
479487
minval, indmin
480488
end
489+
function findmintturbo(x)
490+
indmin = 0
491+
minval = typemax(eltype(x))
492+
@tturbo for i eachindex(x)
493+
newmin = x[i] < minval
494+
minval = newmin ? x[i] : minval
495+
indmin = newmin ? i : indmin
496+
end
497+
minval, indmin
498+
end
481499
function findminturbo_u4(x)
482500
indmin = 0
483501
minval = typemax(eltype(x))
@@ -497,9 +515,21 @@ T = Float32
497515
mv, mi = findminturbo(a)
498516
mv2, mi2 = findminturbo_u4(a)
499517
@test mv == a[mi] == minimum(a) == mv2 == a[mi2]
518+
for n in 1000:1000:10_000
519+
x = rand(-T(100):T(100), n);
520+
@test absmax_tturbo(x) == mapreduce(abs, max, x)
521+
mv, mi = findmintturbo(x)
522+
@test mv == x[mi] == minimum(x)
523+
end
500524
else
501525
a = rand(T, N); b = rand(T, N);
502526
@test findmin(a) == findminturbo(a) == findminturbo_u4(a)
527+
for n in 1000:1000:10_000
528+
x = randn(T, n);
529+
@test absmax_tturbo(x) == mapreduce(abs, max, x)
530+
mv, mi = findmintturbo(x)
531+
@test mv == x[mi] == minimum(x)
532+
end
503533
end;
504534
c1 = similar(a); c2 = similar(a);
505535

test/outer_reductions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
function reg_term(omega, B = size(omega,2); alpha=0.01)
23
reg = 0.0
34
for b in 1:B
@@ -199,5 +200,6 @@ end
199200
end
200201
omega = rand(87,87);
201202
@test reg_term(omega) reg_term_turbo(omega)
203+
202204
end
203205

0 commit comments

Comments
 (0)