Skip to content

Commit a45eacc

Browse files
committed
Fixed a bug in lowering code where a store is not unrolled but the parent is.
1 parent 74801ee commit a45eacc

File tree

9 files changed

+553
-536
lines changed

9 files changed

+553
-536
lines changed

src/LoopVectorization.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ include("add_constants.jl")
2929
include("add_ifelse.jl")
3030
include("broadcast.jl")
3131
include("determinestrategy.jl")
32+
include("lower_compute.jl")
33+
include("lower_constant.jl")
34+
include("lower_memory_common.jl")
35+
include("lower_load.jl")
36+
include("lower_store.jl")
3237
include("lowering.jl")
3338
include("condense_loopset.jl")
3439
include("reconstruct_loopset.jl")

src/costs.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,11 @@ const REDUCTION_CLASS = Dict{Instruction,Float64}(
185185
Instruction(:vfmadd_fast) => ADDITIVE_IN_REDUCTIONS,
186186
Instruction(:vfmsub_fast) => ADDITIVE_IN_REDUCTIONS,
187187
Instruction(:vfnmadd_fast) => ADDITIVE_IN_REDUCTIONS,
188-
Instruction(:vfnmsub_fast) => ADDITIVE_IN_REDUCTIONS
188+
Instruction(:vfnmsub_fast) => ADDITIVE_IN_REDUCTIONS,
189+
Instruction(:reduced_add) => ADDITIVE_IN_REDUCTIONS,
190+
Instruction(:reduced_prod) => MULTIPLICATIVE_IN_REDUCTIONS,
191+
Instruction(:reduced_all) => ALL,
192+
Instruction(:reduced_any) => ANY
189193
)
190194
reduction_instruction_class(instr::Symbol) = get(REDUCTION_CLASS, Instruction(instr), NaN)
191195
reduction_instruction_class(instr::Instruction) = get(REDUCTION_CLASS, instr, NaN)

src/lower_compute.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# A compute op needs to know the unrolling and tiling status of each of its parents.
2+
#
3+
function lower_compute_scalar!(
4+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
5+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
6+
)
7+
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, false)
8+
end
9+
function lower_compute_unrolled!(
10+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
11+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
12+
)
13+
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, true)
14+
end
15+
struct FalseCollection end
16+
Base.getindex(::FalseCollection, i...) = false
17+
function lower_compute!(
18+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
19+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing,
20+
opunrolled = unrolled loopdependencies(op)
21+
)
22+
23+
var = op.variable
24+
mvar = mangledvar(op)
25+
parents_op = parents(op)
26+
nparents = length(parents_op)
27+
parentstiled = if suffix === nothing
28+
optiled = false
29+
tiledouterreduction = -1
30+
FalseCollection()
31+
else
32+
tiledouterreduction = isouterreduction(op)
33+
suffix_ = Symbol(suffix, :_)
34+
if tiledouterreduction == -1
35+
mvar = Symbol(mvar, suffix_)
36+
end
37+
optiled = true
38+
[tiled loopdependencies(opp) for opp parents_op]
39+
end
40+
parentsunrolled = [unrolled loopdependencies(opp) || unrolled reducedchildren(opp) for opp parents_op]
41+
if !opunrolled && any(parentsunrolled)
42+
parents_op = copy(parents_op)
43+
for i eachindex(parentsunrolled)
44+
parentsunrolled[i] || continue
45+
parentsunrolled[i] = false
46+
parentop = parents_op[i]
47+
newparentop = Operation(
48+
parentop.identifier, gensym(parentop.variable), parentop.elementbytes, parentop.instruction, parentop.node_type,
49+
parentop.dependencies, parentop.reduced_deps, parentop.parents, parentop.ref, parentop.reduced_children
50+
)
51+
parentname = mangledvar(parentop)
52+
newparentname = mangledvar(newparentop)
53+
parents_op[i] = newparentop
54+
if parentstiled[i]
55+
parentname = Symbol(parentname, suffix_)
56+
newparentname = Symbol(newparentname, suffix_)
57+
end
58+
for u 0:U-1
59+
push!(q.args, Expr(:(=), Symbol(newparentname, u), Symbol(parentname, u)))
60+
end
61+
# @show #instruction(newparentop)
62+
reduce_expr!(q, newparentname, Instruction(reduction_to_single_vector(instruction(newparentop))), U)
63+
push!(q.args, Expr(:(=), newparentname, Symbol(newparentname, 0)))
64+
end
65+
end
66+
instr = op.instruction
67+
# cache unroll and tiling check of parents
68+
# not broadcasted, because we use frequent checks of individual bools
69+
# making BitArrays inefficient.
70+
# parentsyms = [opp.variable for opp ∈ parents(op)]
71+
Uiter = opunrolled ? U - 1 : 0
72+
maskreduct = mask !== nothing && isreduction(op) && vectorized reduceddependencies(op) #any(opp -> opp.variable === var, parents_op)
73+
# if a parent is not unrolled, the compiler should handle broadcasting CSE.
74+
# because unrolled/tiled parents result in an unrolled/tiled dependendency,
75+
# we handle both the tiled and untiled case here.
76+
# bajillion branches that go the same way on each iteration
77+
# but smaller function is probably worthwhile. Compiler could theoreically split anyway
78+
# but I suspect that the branches are so cheap compared to the cost of everything else going on
79+
# that smaller size is more advantageous.
80+
modsuffix = 0
81+
for u 0:Uiter
82+
instrcall = Expr(instr) # Expr(:call, instr)
83+
varsym = if tiledouterreduction > 0 # then suffix !== nothing
84+
modsuffix = ((u + suffix*U) & 3)
85+
Symbol(mvar, modsuffix)
86+
elseif opunrolled
87+
Symbol(mvar, u)
88+
else
89+
mvar
90+
end
91+
for n 1:nparents
92+
parent = mangledvar(parents_op[n])
93+
if n == tiledouterreduction
94+
parent = Symbol(parent, modsuffix)
95+
else
96+
if parentstiled[n]
97+
parent = Symbol(parent, suffix_)
98+
end
99+
if parentsunrolled[n]
100+
parent = Symbol(parent, u)
101+
end
102+
end
103+
push!(instrcall.args, parent)
104+
end
105+
if maskreduct && (u == Uiter || unrolled !== vectorized) # only mask last
106+
push!(q.args, Expr(:(=), varsym, Expr(:call, lv(:vifelse), mask, instrcall, varsym)))
107+
else
108+
push!(q.args, Expr(:(=), varsym, instrcall))
109+
end
110+
end
111+
end
112+
113+

src/lower_constant.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
2+
function lower_constant!(
3+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, U::Int,
4+
suffix::Union{Nothing,Int}, mask::Any = nothing
5+
)
6+
instruction = op.instruction
7+
mvar = variable_name(op, suffix)
8+
constsym = instruction.instr
9+
# constsym = mangledvar(op)
10+
if vectorized loopdependencies(op) || vectorized reducedchildren(op) || vectorized reduceddependencies(op)
11+
# call = Expr(:call, lv(:vbroadcast), W, mangledvar(op))
12+
call = Expr(:call, lv(:vbroadcast), W, constsym)
13+
if unrolled loopdependencies(op) || unrolled reducedchildren(op) || unrolled reduceddependencies(op)
14+
for u 0:U-1
15+
push!(q.args, Expr(:(=), Symbol(mvar, u), call))
16+
end
17+
else
18+
push!(q.args, Expr(:(=), mvar, call))
19+
end
20+
else
21+
if unrolled loopdependencies(op) || unrolled reducedchildren(op) || unrolled reduceddependencies(op)
22+
for u 0:U-1
23+
push!(q.args, Expr(:(=), Symbol(mvar, u), constsym))
24+
end
25+
else
26+
push!(q.args, Expr(:(=), mvar, constsym))
27+
end
28+
end
29+
nothing
30+
end
31+
32+
33+
function lower_licm_constants!(ls::LoopSet)
34+
ops = operations(ls)
35+
for (id, sym) ls.preamble_symsym
36+
setconstantop!(ls, ops[id], sym)
37+
end
38+
for (id,intval) ls.preamble_symint
39+
setop!(ls, ops[id], Expr(:call, lv(:sizeequivalentint), ls.T, intval))
40+
end
41+
for (id,floatval) ls.preamble_symfloat
42+
setop!(ls, ops[id], Expr(:call, lv(:sizeequivalentfloat), ls.T, intval))
43+
end
44+
for id ls.preamble_zeros
45+
setop!(ls, ops[id], Expr(:call, :zero, ls.T))
46+
end
47+
for id ls.preamble_ones
48+
setop!(ls, ops[id], Expr(:call, :one, ls.T))
49+
end
50+
end
51+
52+
53+

src/lower_load.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
function pushvectorload!(q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, mask, vecnotunrolled::Bool)
2+
@unpack u, unrolled = td
3+
ptr = refname(op)
4+
name, mo = name_memoffset(var, op, td, W, vecnotunrolled)
5+
instrcall = Expr(:call, lv(:vload), ptr, mo)
6+
if mask !== nothing && (vecnotunrolled || u == U - 1)
7+
push!(instrcall.args, mask)
8+
end
9+
push!(q.args, Expr(:(=), name, instrcall))
10+
end
11+
function lower_load_scalar!(
12+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
13+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
14+
)
15+
loopdeps = loopdependencies(op)
16+
@assert vectorized loopdeps
17+
var = variable_name(op, suffix)
18+
ptr = refname(op)
19+
isunrolled = unrolled loopdeps
20+
U = isunrolled ? U : 1
21+
for u zero(Int32):Base.unsafe_trunc(Int32,U-1)
22+
varname = varassignname(var, u, isunrolled)
23+
td = UnrollArgs(u, unrolled, tiled, suffix)
24+
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:load), ptr, mem_offset_u(op, td))))
25+
end
26+
nothing
27+
end
28+
function lower_load_vectorized!(
29+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
30+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
31+
)
32+
loopdeps = loopdependencies(op)
33+
@assert vectorized loopdeps
34+
if unrolled loopdeps
35+
umin = zero(Int32)
36+
U = U
37+
else
38+
umin = -one(Int32)
39+
U = 0
40+
end
41+
# Urange = unrolled ∈ loopdeps ? 0:U-1 : 0
42+
var = variable_name(op, suffix)
43+
vecnotunrolled = vectorized !== unrolled
44+
for u umin:Base.unsafe_trunc(Int32,U-1)
45+
td = UnrollArgs(u, unrolled, tiled, suffix)
46+
pushvectorload!(q, op, var, td, U, W, mask, vecnotunrolled)
47+
end
48+
nothing
49+
end
50+
51+
# TODO: this code should be rewritten to be more "orthogonal", so that we're just combining separate pieces.
52+
# Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
53+
function lower_load!(
54+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
55+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
56+
)
57+
if vectorized loopdependencies(op)
58+
lower_load_vectorized!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
59+
else
60+
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
61+
end
62+
end
63+
64+
65+

src/lower_memory_common.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
2+
struct UnrollArgs{T}
3+
u::Int32
4+
unrolled::Symbol
5+
tiled::Symbol
6+
suffix::T
7+
end
8+
function parentind(ind::Symbol, op::Operation)
9+
for (id,opp) enumerate(parents(op))
10+
name(opp) === ind && return id
11+
end
12+
-1
13+
end
14+
function symbolind(ind::Symbol, op::Operation, td::UnrollArgs)
15+
id = parentind(ind, op)
16+
id == -1 && return Expr(:call, :-, ind, one(Int32))
17+
@unpack u, unrolled, tiled, suffix = td
18+
parent = parents(op)[id]
19+
pvar = if tiled loopdependencies(parent)
20+
variable_name(parent, suffix)
21+
else
22+
mangledvar(parent)
23+
end
24+
pvar = unrolled loopdependencies(parent) ? Symbol(pvar, u) : pvar
25+
Expr(:call, :-, pvar, one(Int32))
26+
end
27+
function mem_offset(op::Operation, td::UnrollArgs)
28+
# @assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
29+
ret = Expr(:tuple)
30+
indices = getindices(op)
31+
loopedindex = op.ref.loopedindex
32+
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1
33+
for (n,ind) enumerate(@view(indices[start:end]))
34+
if ind isa Int
35+
push!(ret.args, ind)
36+
elseif loopedindex[n]
37+
push!(ret.args, ind)
38+
else
39+
push!(ret.args, symbolind(ind, op, td))
40+
end
41+
end
42+
ret
43+
end
44+
function mem_offset_u(op::Operation, td::UnrollArgs)
45+
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
46+
@unpack unrolled, u = td
47+
incr = u
48+
ret = Expr(:tuple)
49+
indices = getindices(op)
50+
loopedindex = op.ref.loopedindex
51+
if incr == 0
52+
return mem_offset(op, td)
53+
# append_inds!(ret, indices, loopedindex)
54+
else
55+
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1
56+
for (n,ind) enumerate(@view(indices[start:end]))
57+
if ind isa Int
58+
push!(ret.args, ind)
59+
elseif ind === unrolled
60+
push!(ret.args, Expr(:call, :+, ind, incr))
61+
elseif loopedindex[n]
62+
push!(ret.args, ind)
63+
else
64+
push!(ret.args, symbolind(ind, op, td))
65+
end
66+
end
67+
end
68+
ret
69+
end
70+
function mem_offset_u(op::Operation, td::UnrollArgs, mul::Symbol)
71+
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
72+
@unpack unrolled, u = td
73+
incr = u
74+
ret = Expr(:tuple)
75+
indices = getindices(op)
76+
loopedindex = op.ref.loopedindex
77+
if incr == 0
78+
return mem_offset(op, td)
79+
# append_inds!(ret, indices, loopedindex)
80+
else
81+
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1
82+
for (n,ind) enumerate(@view(indices[start:end]))
83+
if ind isa Int
84+
push!(ret.args, ind)
85+
elseif ind === unrolled
86+
push!(ret.args, Expr(:call, :+, ind, Expr(:call, lv(:valmul), mul, incr)))
87+
elseif loopedindex[n]
88+
push!(ret.args, ind)
89+
else
90+
push!(ret.args, symbolind(ind, op, td))
91+
end
92+
end
93+
end
94+
ret
95+
end
96+
97+
# function add_expr(q, incr)
98+
# if q.head === :call && q.args[2] === :+
99+
# qc = copy(q)
100+
# push!(qc.args, incr)
101+
# qc
102+
# else
103+
# Expr(:call, :+, q, incr)
104+
# end
105+
# end
106+
function varassignname(var::Symbol, u::Int32, isunrolled::Bool)
107+
isunrolled ? Symbol(var, u) : var
108+
end
109+
# name_memoffset only gets called when vectorized
110+
function name_memoffset(var::Symbol, op::Operation, td::UnrollArgs, W::Symbol, vecnotunrolled::Bool)
111+
@unpack u, unrolled = td
112+
if u < 0 # sentinel value meaning not unrolled
113+
name = var
114+
mo = mem_offset(op, td)
115+
else
116+
name = Symbol(var, u)
117+
mo = vecnotunrolled ? mem_offset_u(op, td) : mem_offset_u(op, td, W)
118+
end
119+
name, mo
120+
end
121+

0 commit comments

Comments
 (0)