Skip to content

Commit 2c87fce

Browse files
committed
Only use vf(n)m(add/sub)d231 functions in tiled kernel reductions.
1 parent 2ae8a78 commit 2c87fce

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.6.20"
4+
version = "0.6.21"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -13,9 +13,9 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1313

1414
[compat]
1515
Parameters = "0"
16-
SIMDPirates = "~0.6.5"
16+
SIMDPirates = "~0.6.6"
1717
SLEEFPirates = "~0.4"
18-
VectorizationBase = "~0.7"
18+
VectorizationBase = "~0.7.1"
1919
julia = "1.1"
2020

2121
[extras]

src/lower_compute.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,14 @@ function lower_compute!(
5656
# making BitArrays inefficient.
5757
# parentsyms = [opp.variable for opp ∈ parents(op)]
5858
Uiter = opunrolled ? U - 1 : 0
59-
maskreduct = mask !== nothing && isreduction(op) && vectorized reduceddependencies(op) #any(opp -> opp.variable === var, parents_op)
59+
isreduct = isreduction(op)
60+
if !isnothing(suffix) && isreduct
61+
instrfid = findfirst(isequal(instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
62+
if instrfid !== nothing
63+
instr = Instruction((:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)[instrfid])
64+
end
65+
end
66+
maskreduct = mask !== nothing && isreduct && vectorized reduceddependencies(op) #any(opp -> opp.variable === var, parents_op)
6067
# if a parent is not unrolled, the compiler should handle broadcasting CSE.
6168
# because unrolled/tiled parents result in an unrolled/tiled dependendency,
6269
# we handle both the tiled and untiled case here.
@@ -90,10 +97,15 @@ function lower_compute!(
9097
push!(instrcall.args, parent)
9198
end
9299
if maskreduct && (u == Uiter || unrolled !== vectorized) # only mask last
93-
push!(q.args, Expr(:(=), varsym, Expr(:call, lv(:vifelse), mask, instrcall, varsym)))
94-
else
95-
push!(q.args, Expr(:(=), varsym, instrcall))
100+
if last(instrcall.args) == varsym
101+
pushfirst!(instrcall.args, lv(:vifelse))
102+
insert!(instrcall.args, 3, mask)
103+
else
104+
push!(q.args, Expr(:(=), varsym, Expr(:call, lv(:vifelse), mask, instrcall, varsym)))
105+
continue
106+
end
96107
end
108+
push!(q.args, Expr(:(=), varsym, instrcall))
97109
end
98110
end
99111

0 commit comments

Comments
 (0)