Skip to content

Commit 2563d38

Browse files
committed
Fixed bug that made it possible for an op to point to the wrong ref after condense -> reconstructing, added CSE for constants, and improved cost evaluation of tiling. Fixes #45.
1 parent 891ab4d commit 2563d38

14 files changed

+295
-77
lines changed

src/add_compute.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ function add_reduction_update_parent!(
118118
reductsym = gensym(:reduction)
119119
reductinit = add_constant!(ls, gensym(:reductzero), loopdependencies(parent), reductsym, elementbytes, :numericconstant)
120120
if reduct_zero === :zero
121-
push!(ls.preamble_zeros, identifier(reductinit))
121+
push!(ls.preamble_zeros, (identifier(reductinit), IntOrFloat))
122122
elseif reduct_zero === :one
123-
push!(ls.preamble_ones, identifier(reductinit))
123+
push!(ls.preamble_ones, (identifier(reductinit), IntOrFloat))
124124
else
125125
if reductzero === :true || reductzero === :false
126126
pushpreamble!(ls, Expr(:(=), name(reductinit), reductzero))
@@ -185,11 +185,7 @@ function add_compute!(
185185
add_parent!(parents, deps, reduceddeps, ls, arg, elementbytes, position)
186186
end
187187
elseif arg ls.loopsymbols
188-
loopsym = gensym(arg)
189-
pushpreamble!(ls, Expr(:(=), loopsym, LoopValue()))
190-
loopsymop = add_simple_load!(ls, gensym(loopsym), ArrayReference(loopsym, [arg]), elementbytes, false)
191-
push!(ls.syms_aliasing_refs, name(loopsymop))
192-
push!(ls.refs_aliasing_syms, loopsymop.ref)
188+
loopsymop = add_loopvalue!(ls, arg, elementbytes)
193189
pushparent!(parents, deps, reduceddeps, loopsymop)
194190
else
195191
add_parent!(parents, deps, reduceddeps, ls, arg, elementbytes, position)

src/add_constants.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,33 @@ function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
88
pushpreamble!(ls, Expr(:(=), sym, var))
99
add_constant!(ls, sym, elementbytes)
1010
end
11+
function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
12+
op = Operation(length(operations(ls)), gensym(:loopconstnumber), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
13+
ops = operations(ls)
14+
typ = var isa Integer ? HardInt : HardFloat
15+
if iszero(var)
16+
for (id,typ_) ls.preamble_zeros
17+
(instruction(ops[id]) === LOOPCONSTANT && typ == typ_) && return ops[id]
18+
end
19+
push!(ls.preamble_zeros, (identifier(op),typ))
20+
elseif isone(var)
21+
for (id,typ_) ls.preamble_ones
22+
(instruction(ops[id]) === LOOPCONSTANT && typ == typ_) && return ops[id]
23+
end
24+
push!(ls.preamble_ones, (identifier(op),typ))
25+
elseif var isa Integer
26+
for (id,ivar) ls.preamble_symint
27+
(instruction(ops[id]) === LOOPCONSTANT && ivar == var) && return ops[id]
28+
end
29+
push!(ls.preamble_symint, (identifier(op), var))
30+
else#if var isa FloatX
31+
for (id,fvar) ls.preamble_symfloat
32+
(instruction(ops[id]) === LOOPCONSTANT && fvar == var) && return ops[id]
33+
end
34+
push!(ls.preamble_symfloat, (identifier(op), var))
35+
end
36+
pushop!(ls, op)
37+
end
1138
function add_constant!(ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
1239
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
1340
add_vptr!(ls, op)
@@ -36,12 +63,6 @@ function add_constant!(
3663
ls::LoopSet, value::Number, deps::Vector{Symbol}, assignedsym::Symbol, elementbytes::Int
3764
)
3865
op = add_constant!(ls, gensym(Symbol(value)), deps, assignedsym, elementbytes, :numericconstant)
39-
if iszero(value)
40-
push!(ls.preamble_zeros, identifier(op))
41-
elseif isone(value)
42-
push!(ls.preamble_ones, identifier(op))
43-
else
44-
pushpreamble!(ls, op, value)
45-
end
66+
pushpreamble!(ls, op, value)
4667
op
4768
end

src/add_loads.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,22 @@ function add_load_getindex!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::In
5454
add_load!(ls, var, array, rawindices, elementbytes)
5555
end
5656

57+
function add_loopvalue!(ls::LoopSet, arg::Symbol, elementbytes::Int)
58+
# check for CSE opportunity
59+
loopsym = Symbol("##LOOPSYMBOL##", arg)
60+
ar = ArrayReference(loopsym, [arg])
61+
for op operations(ls)
62+
if isload(op) && op.ref.ref == ar
63+
return op
64+
end
65+
end
66+
pushpreamble!(ls, Expr(:(=), loopsym, LoopValue()))
67+
loopsymop = add_simple_load!(ls, gensym(loopsym), ar, elementbytes, false)
68+
push!(ls.syms_aliasing_refs, name(loopsymop))
69+
push!(ls.refs_aliasing_syms, loopsymop.ref)
70+
loopsymop
71+
end
72+
5773

5874
struct LoopValue end
5975
@inline VectorizationBase.stridedpointer(::LoopValue) = LoopValue()

src/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function add_broadcast!(
7373
# set Cₘₙ = 0
7474
# setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
7575
setC = add_constant!(ls, gensym(:zero), cloopsyms, mC, elementbytes, :numericconstant)
76-
push!(ls.preamble_zeros, identifier(setC))
76+
push!(ls.preamble_zeros, (identifier(setC), IntOrFloat))
7777
# compute Cₘₙ += Aₘₖ * Bₖₙ
7878
reductop = Operation(
7979
ls, mC, elementbytes, :vmuladd, compute, reductdeps, Symbol[k], Operation[loadA, loadB, setC]

src/condense_loopset.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ isload(os::OperationStruct) = os.node_type == memload
5555
isstore(os::OperationStruct) = os.node_type == memstore
5656
iscompute(os::OperationStruct) = os.node_type == compute
5757
isconstant(os::OperationStruct) = os.node_type == constant
58-
function findmatchingarray(ls::LoopSet, array::Symbol)
58+
function findmatchingarray(ls::LoopSet, mref::ArrayReferenceMeta)
5959
id = 0x01
60-
for as ls.refs_aliasing_syms
61-
vptr(as) === array && return id
60+
for r ls.refs_aliasing_syms
61+
r == mref && return id
6262
id += 0x01
6363
end
6464
0x00
@@ -96,7 +96,7 @@ function OperationStruct!(varnames::Vector{Symbol}, ls::LoopSet, op::Operation)
9696
rd = reduceddeps_uint(ls, op)
9797
cd = childdeps_uint(ls, op)
9898
p = parents_uint(ls, op)
99-
array = accesses_memory(op) ? findmatchingarray(ls, vptr(op.ref)) : 0x00
99+
array = accesses_memory(op) ? findmatchingarray(ls, op.ref) : 0x00
100100
OperationStruct(
101101
ld, rd, cd, p, op.node_type, array, findindoradd!(varnames, name(op))
102102
)
@@ -228,7 +228,7 @@ end
228228
end
229229

230230
# Try to condense in type stable manner
231-
function generate_call(ls::LoopSet, IUT)
231+
function generate_call(ls::LoopSet, IUT, debug::Bool = false)
232232
operation_descriptions = Expr(:curly, :Tuple)
233233
varnames = Symbol[]
234234
for op operations(ls)
@@ -243,12 +243,19 @@ function generate_call(ls::LoopSet, IUT)
243243
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
244244
loop_bounds = loop_boundaries(ls)
245245
inline, U, T = IUT
246-
if inline
247-
q = Expr(:call, lv(:_avx_!), Expr(:call, Expr(:curly, :Val, (U,T))), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
246+
if inline | debug
247+
func = debug ? lv(:_avx_loopset) : lv(:_avx_!)
248+
q = Expr(
249+
:call, func, Expr(:call, Expr(:curly, :Val, (U,T))),
250+
operation_descriptions, arrayref_descriptions, argmeta, loop_bounds
251+
)
248252
foreach(ref -> push!(q.args, vptr(ref)), ls.refs_aliasing_syms)
249253
else
250254
arraydescript = Expr(:tuple)
251-
q = Expr(:call, lv(:__avx__!), Expr(:call, Expr(:curly, :Val, (U,T))), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds, arraydescript)
255+
q = Expr(
256+
:call, lv(:__avx__!), Expr(:call, Expr(:curly, :Val, (U,T))),
257+
operation_descriptions, arrayref_descriptions, argmeta, loop_bounds, arraydescript
258+
)
252259
for array ls.includedactualarrays
253260
push!(q.args, Expr(:call, lv(:unwrap_array), array))
254261
push!(arraydescript.args, Expr(:call, lv(:array_wrapper), array))
@@ -361,6 +368,11 @@ function setup_call_inline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
361368
append!(ls.preamble.args, q.args)
362369
ls.preamble
363370
end
371+
function setup_call_debug(ls::LoopSet)
372+
# avx_loopset(instr, ops, arf, AM, LB, vargs)
373+
pushpreamble!(ls, generate_call(ls, (true,zero(Int8),zero(Int8)), true))
374+
ls.preamble
375+
end
364376
function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
365377
# We outline/inline at the macro level by creating/not creating an anonymous function.
366378
# The old API instead was based on inlining or not inline the generated function, but

src/constructors.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,7 @@ macro _avx(arg, q)
151151
end
152152

153153

154+
macro avx_debug(q)
155+
esc(LoopVectorization.setup_call_debug(LoopSet(q, __module__)))
156+
end
154157

src/determinestrategy.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,12 @@ function evaluate_cost_unroll(
9191
rd = reduceddependencies(op)
9292
hasintersection(rd, nested_loop_syms[1:end-length(rd)]) && return Inf
9393
included_vars[id] = true
94-
94+
# @show op first(cost(op, vectorized, Wshift, size_T)), iter
9595
total_cost += iter * first(cost(op, vectorized, Wshift, size_T))
9696
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
9797
end
9898
end
99-
total_cost
99+
total_cost + stride_penalty(ls, order)
100100
end
101101

102102
# only covers vectorized ops; everything else considered lifted?
@@ -198,7 +198,10 @@ function determine_unroll_factor(
198198
end
199199

200200
function tile_cost(X, U, T, UL, TL)
201-
X[1] + X[4] + X[2] * (num_iterations(TL, T)/TL) + X[3] * (num_iterations(UL, U)/UL)
201+
Tfactor = (num_iterations(TL, T)/TL)
202+
Ufactor = (num_iterations(UL, U)/UL)
203+
# X[1]*Tfactor*Ufactor + X[4] + X[2] * Tfactor + X[3] * Ufactor
204+
X[1] + X[4] + X[2] * Tfactor + X[3] * Ufactor
202205
end
203206
function solve_tilesize(X, R, UL, TL)
204207
@inbounds any(iszero, (R[1],R[2],R[3])) && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
@@ -338,7 +341,7 @@ function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol})
338341
num_loops = length(order)
339342
contigsym = first(loopdependencies(op))
340343
contigsym == Symbol("##DISCONTIGUOUSSUBARRAY##") && return 0
341-
iter = 0
344+
iter = 1
342345
for i 0:num_loops - 1
343346
loopsym = order[num_loops - i]
344347
loopsym === contigsym && return iter
@@ -386,34 +389,29 @@ function evaluate_cost_tile(
386389
reg_pressure = reg_pres_buf(ls)
387390
# @inbounds reg_pressure[2] = 1
388391
# @inbounds reg_pressure[3] = 1
389-
unrollediter = length(ls, unrolled)
390-
tilediter = length(ls, tiled)
391-
unrollediter = unrolled === vectorized ? num_iterations(unrollediter, W) : unrollediter # tiled cannot be vectorized, so do not check
392-
iter::Int = tilediter * unrollediter
392+
iter::Int = 1
393393
for n 1:N
394394
itersym = order[n]
395395
# Add to set of defined symbols
396396
push!(nested_loop_syms, itersym)
397-
stepsize = 1
398-
if n > 2
399-
itersymlooplen = length(ls, itersym)
400-
iter *= itersym === vectorized ? num_iterations(itersymlooplen, W) : itersymlooplen
401-
end
397+
looplength = length(ls, itersym)
398+
iter *= itersym === vectorized ? num_iterations(looplength, W) : looplength
402399
# check which vars we can define at this level of loop nest
403400
for (id, op) enumerate(ops)
404401
# isconstant(op) && continue
405402
# @assert id == identifier(op)+1 # testing, for now
406403
# won't define if already defined...
407404
included_vars[id] && continue
408405
# it must also be a subset of defined symbols
409-
loopdependencies(op) nested_loop_syms || continue
406+
all(ld -> ld nested_loop_syms, loopdependencies(op)) || continue
410407
# # @show nested_loop_syms
411408
# # @show reduceddependencies(op)
412409
rd = reduceddependencies(op)
413410
hasintersection(rd, nested_loop_syms[1:end-length(rd)]) && return 0,0,Inf
414411
included_vars[id] = true
415412
unrolledtiled[1,id] = unrolled loopdependencies(op)
416413
unrolledtiled[2,id] = tiled loopdependencies(op)
414+
# @show op iter, unrolledtiled[:,id]
417415
iters[id] = iter
418416
innerloop loopdependencies(op) && set_upstream_family!(descendentsininnerloop, op, true)
419417
end

src/graphs.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ Base.size(lo::LoopOrder) = (2,2,2,length(lo.loopnames))
140140
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i::Int) = lo.oporder[i]
141141
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i...) = lo.oporder[LinearIndices(size(lo))[i...]]
142142

143+
@enum NumberType::Int8 HardInt HardFloat IntOrFloat INVALID
144+
143145
# Must make it easy to iterate
144146
# outer_reductions is a vector of indices (within operation vectors) of the reduction operation, eg the vmuladd op in a dot product
145147
# O(N) search is faster at small sizes
@@ -154,8 +156,8 @@ struct LoopSet
154156
preamble_symsym::Vector{Tuple{Int,Symbol}}
155157
preamble_symint::Vector{Tuple{Int,Int}}
156158
preamble_symfloat::Vector{Tuple{Int,Float64}}
157-
preamble_zeros::Vector{Int}
158-
preamble_ones::Vector{Int}
159+
preamble_zeros::Vector{Tuple{Int,NumberType}}
160+
preamble_ones::Vector{Tuple{Int,NumberType}}
159161
includedarrays::Vector{Symbol}
160162
includedactualarrays::Vector{Symbol}
161163
syms_aliasing_refs::Vector{Symbol}
@@ -199,21 +201,41 @@ function pushpreamble!(ls::LoopSet, op::Operation, v::Symbol)
199201
end
200202
nothing
201203
end
202-
pushpreamble!(ls::LoopSet, op::Operation, v::Integer) = push!(ls.preamble_symint, (identifier(op),convert(Int,v)))
203-
pushpreamble!(ls::LoopSet, op::Operation, v::Real) = push!(ls.preamble_symfloat, (identifier(op),convert(Float64,v)))
204+
function pushpreamble!(ls::LoopSet, op::Operation, v::Number)
205+
typ = v isa Integer ? HardInt : HardFloat
206+
id = identifier(op)
207+
if iszero(v)
208+
push!(ls.preamble_zeros, (id, typ))
209+
elseif isone(v)
210+
push!(ls.preamble_ones, (id, typ))
211+
elseif v isa Integer
212+
push!(ls.preamble_symint, (id, convert(Int,v)))
213+
else
214+
push!(ls.preamble_symfloat, (id, convert(Float64,v)))
215+
end
216+
end
204217
pushpreamble!(ls::LoopSet, ex::Expr) = push!(ls.preamble.args, ex)
205218
function pushpreamble!(ls::LoopSet, op::Operation, RHS::Expr)
206219
c = gensym(:licmconst)
207220
if RHS.head === :call && first(RHS.args) === :zero
208-
push!(ls.preamble_zeros, identifier(op))
221+
push!(ls.preamble_zeros, (identifier(op), IntOrFloat))
209222
elseif RHS.head === :call && first(RHS.args) === :one
210-
push!(ls.preamble_ones, identifier(op))
223+
push!(ls.preamble_ones, (identifier(op), IntOrFloat))
211224
else
212225
pushpreamble!(ls, Expr(:(=), c, RHS))
213226
pushpreamble!(ls, op, c)
214227
end
215228
nothing
216229
end
230+
function zerotype(ls::LoopSet, op::Operation)
231+
opid = identifier(op)
232+
for (id,typ) ls.preamble_zeros
233+
id == opid && return typ
234+
end
235+
INVALID
236+
end
237+
238+
217239

218240
includesarray(ls::LoopSet, array::Symbol) = array ls.includedarrays
219241

@@ -425,7 +447,7 @@ function add_operation!(
425447
elseif f === :zero || f === :one
426448
c = gensym(f)
427449
op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS, elementbytes, :numericconstant)
428-
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, identifier(op))
450+
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, (identifier(op), IntOrFloat))
429451
op
430452
else
431453
add_compute!(ls, LHS, RHS, elementbytes, position)
@@ -452,7 +474,7 @@ function add_operation!(
452474
elseif f === :zero || f === :one
453475
c = gensym(f)
454476
op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS_sym, elementbytes, :numericconstant)
455-
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, identifier(op))
477+
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, (identifier(op), IntOrFloat))
456478
op
457479
else
458480
add_compute!(ls, LHS_sym, RHS, elementbytes, position, LHS_ref)

0 commit comments

Comments
 (0)