Skip to content

Commit 9379e49

Browse files
committed
Make determinestrategy more flexible in allowing different unroll/tile/vectorized orders when tiling.
1 parent 37f43c4 commit 9379e49

14 files changed

+434
-452
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector
99
maybestaticfirst, maybestaticlast
1010
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod,
1111
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
12-
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, #prefetch,
12+
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, sizeequivalentfloat, sizeequivalentint, #prefetch,
1313
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone
1414
using Base.Broadcast: Broadcasted, DefaultArrayStyle
1515
using LinearAlgebra: Adjoint, Transpose

src/constructors.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ true
8888
8989
"""
9090
macro avx(q)
91+
q = macroexpand(__module__, q)
9192
q2 = if q.head === :for
9293
setup_call(LoopSet(q, __module__))
9394
else# assume broadcast
@@ -134,12 +135,14 @@ end
134135
macro avx(arg, q)
135136
@assert q.head === :for
136137
@assert arg.head === :(=)
138+
q = macroexpand(__module__, q)
137139
inline, U, T = check_macro_kwarg(arg)
138140
ls = LoopSet(q, __module__)
139141
esc(setup_call(ls, inline, U, T))
140142
end
141143
macro avx(arg1, arg2, q)
142144
@assert q.head === :for
145+
q = macroexpand(__module__, q)
143146
inline, U, T = check_macro_kwarg(arg1)
144147
inline, U, T = check_macro_kwarg(arg2, inline, U, T)
145148
esc(setup_call(LoopSet(q, __module__), inline, U, T))
@@ -154,15 +157,18 @@ While `@avx` punts to a generated function to enable type-based analysis, `_@avx
154157
works on just the expressions. This requires that it makes a number of default assumptions.
155158
"""
156159
macro _avx(q)
160+
q = macroexpand(__module__, q)
157161
esc(lower(LoopSet(q, __module__)))
158162
end
159163
macro _avx(arg, q)
160164
@assert q.head === :for
165+
q = macroexpand(__module__, q)
161166
inline, U, T = check_macro_kwarg(arg)
162167
esc(lower(LoopSet(q, __module__), U, T))
163168
end
164169

165170

166171
macro avx_debug(q)
172+
q = macroexpand(__module__, q)
167173
esc(LoopVectorization.setup_call_debug(LoopSet(q, __module__)))
168174
end

src/determinestrategy.jl

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,15 @@ function unitstride(ls::LoopSet, op::Operation, s::Symbol)
4343
end
4444

4545
function register_pressure(op::Operation)
46-
if isconstant(op)
46+
if isconstant(op) || isloopvalue(op)
4747
0
4848
else
4949
instruction_cost(instruction(op)).register_pressure
5050
end
5151
end
5252
function cost(ls::LoopSet, op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int = op.elementbytes)
5353
isconstant(op) && return 0.0, 0, 1
54+
isloopvalue(op) && return 0.0, 0, 1
5455
# Wshift == dependson(op, unrolled) ? Wshift : 0
5556
# c = first(cost(instruction(op), Wshift, size_T))::Int
5657
instr = Instruction(:LoopVectorization, instruction(op).instr)
@@ -73,8 +74,8 @@ function cost(ls::LoopSet, op::Operation, unrolled::Symbol, Wshift::Int, size_T:
7374
# else # vmov(a/u)pd
7475
end
7576
elseif instr === :setindex! # broadcast or reductionstore; if store we want to penalize reduction
76-
srt *= 2
77-
sl *= 2
77+
srt *= 3
78+
sl *= 3
7879
end
7980
end
8081
srt, sl, srp
@@ -400,16 +401,30 @@ function stride_penalty(ls::LoopSet, order::Vector{Symbol})
400401
end
401402
stridepenalty * 1e-9
402403
end
404+
function convolution_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
405+
(u1 loopdependencies(op) && u2 loopdependencies(op)) || return 1.0
406+
istranslation = false
407+
inds = getindices(op); li = op.ref.loopedindex
408+
for i eachindex(li)
409+
if !li[i]
410+
opp = findparent(ls, inds[i + (first(inds) === Symbol("##DISCONTIGUOUSSUBARRAY##"))])
411+
if instruction(opp).instr (:+, :-) && u1 loopdependencies(opp) && u2 loopdependencies(opp)
412+
istranslation = true
413+
end
414+
end
415+
end
416+
istranslation ? 0.25 : 1.0
417+
end
403418
# Just tile outer two loops?
404419
# But optimal order within tile must still be determined
405420
# as well as size of the tiles.
406421
function evaluate_cost_tile(
407-
ls::LoopSet, order::Vector{Symbol}, vectorized::Symbol
422+
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, tiled::Symbol, vectorized::Symbol
408423
)
409424
N = length(order)
410425
@assert N 2 "Cannot tile merely $N loops!"
411-
tiled = order[1]
412-
unrolled = order[2]
426+
# tiled = order[1]
427+
# unrolled = order[2]
413428
ops = operations(ls)
414429
nops = length(ops)
415430
included_vars = fill!(resize!(ls.included_vars, nops), false)
@@ -464,6 +479,10 @@ function evaluate_cost_tile(
464479
isunrolled = unrolledtiled[1,id]
465480
istiled = unrolledtiled[2,id]
466481
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
482+
if isload(op)
483+
factor = convolution_cost_factor(ls, op, unrolled, tiled, vectorized)
484+
rt *= factor#; rp *= factor;
485+
end
467486
rp = opisininnerloop ? rp : 0 # we only care about register pressure within the inner most loop
468487
rt *= iters[id]
469488
if isunrolled && istiled # no cost decrease; cost must be repeated
@@ -556,54 +575,56 @@ function choose_tile(ls::LoopSet)
556575
lo = LoopOrders(ls)
557576
# @show lo.syms ls.loop_order.bestorder
558577
best_order = copyto!(ls.loop_order.bestorder, lo.syms)
559-
best_vec = first(best_order) # filler
578+
best_unrolled = best_tiled = best_vec = first(best_order) # filler
560579
new_order, state = iterate(lo) # right now, new_order === best_order
561580
U, T, lowest_cost = 0, 0, Inf
581+
nloops = length(new_order)
562582
while true
563-
for new_vec @view(new_order[2:end]) # view to skip first
564-
U_temp, T_temp, cost_temp = evaluate_cost_tile(ls, new_order, new_vec)
565-
if cost_temp < lowest_cost
566-
lowest_cost = cost_temp
567-
U, T = U_temp, T_temp
568-
best_vec = new_vec
569-
copyto!(best_order, new_order)
570-
save_tilecost!(ls)
583+
for new_vec new_order # view to skip first
584+
for nt 1:nloops-1
585+
new_tiled = new_order[nt]
586+
for new_unrolled @view(new_order[nt+1:end])
587+
U_temp, T_temp, cost_temp = evaluate_cost_tile(ls, new_order, new_unrolled, new_tiled, new_vec)
588+
if cost_temp < lowest_cost
589+
lowest_cost = cost_temp
590+
U, T = U_temp, T_temp
591+
best_vec = new_vec
592+
best_tiled = new_tiled
593+
best_unrolled = new_unrolled
594+
copyto!(best_order, new_order)
595+
save_tilecost!(ls)
596+
end
597+
end
571598
end
572599
end
573600
iter = iterate(lo, state)
574-
iter === nothing && return best_order, best_vec, U, T, lowest_cost
601+
iter === nothing && return best_order, best_unrolled, best_tiled, best_vec, U, T, lowest_cost
575602
new_order, state = iter
576603
end
577604
end
578605
# Last in order is the inner most loop
579606
function choose_order(ls::LoopSet)
580607
if num_loops(ls) > 1
581-
torder, tvec, tU, tT, tc = choose_tile(ls)
608+
torder, tunroll, ttile, tvec, tU, tT, tc = choose_tile(ls)
582609
else
583610
tc = Inf
584611
end
585612
uorder, uvec, uc = choose_unroll_order(ls, tc)
586613
if num_loops(ls) > 1 && tc uc
587-
return torder, tvec, min(tU, tT), tT
614+
return torder, tunroll, ttile, tvec, min(tU, tT), tT
588615
# return torder, tvec, 4, 4#5, 5
589616
else
590-
return uorder, uvec, determine_unroll_factor(ls, uorder, first(uorder), uvec), -1
617+
return uorder, first(uorder), Symbol("##undefined##"), uvec, determine_unroll_factor(ls, uorder, first(uorder), uvec), -1
591618
end
592619
end
593620

594621
function register_pressure(ls::LoopSet)
595-
# uses unroll of 1 if not tiling
596-
if num_loops(ls) > 1
597-
torder, tvec, tU, tT, tc = choose_tile(ls)
622+
order, unroll, vec, U, T = choose_order(ls)
623+
if T == -1
624+
sum(register_pressure, operations(ls))
598625
else
599-
tc = Inf
600-
end
601-
uorder, uvec, uc = choose_unroll_order(ls, tc)
602-
if num_loops(ls) > 1 && tc uc # tile
603626
rp = @view ls.reg_pressure[:,1]
604627
tU * tT * rp[1] + tU * rp[2] + rp[3] + rp[4]
605-
else
606-
sum(register_pressure, operations(ls))
607-
end
628+
end
608629
end
609630

src/graphs.jl

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,30 @@
3636

3737
# end
3838

39+
struct UnrollSpecification
40+
unrolledloopnum::Int
41+
tiledloopnum::Int
42+
vectorizedloopnum::Int
43+
U::Int
44+
T::Int
45+
end
46+
# UnrollSpecification(ls::LoopSet, unrolled::Loop, vectorized::Symbol, U, T) = UnrollSpecification(ls, unrolled.itersymbol, vectorized, U, T)
47+
function UnrollSpecification(us::UnrollSpecification, U, T)
48+
@unpack unrolledloopnum, tiledloopnum, vectorizedloopnum = us
49+
UnrollSpecification(unrolledloopnum, tiledloopnum, vectorizedloopnum, U, T)
50+
end
51+
function UnrollSpecification(us::UnrollSpecification; U = us.U, T = us.T)
52+
@unpack unrolledloopnum, tiledloopnum, vectorizedloopnum = us
53+
UnrollSpecification(unrolledloopnum, tiledloopnum, vectorizedloopnum, U, T)
54+
end
55+
isunrolled(us::UnrollSpecification, n::Int) = us.unrolledloopnum == n
56+
istiled(us::UnrollSpecification, n::Int) = !isunrolled(us, n) && us.tiledloopnum == n
57+
isvectorized(us::UnrollSpecification, n::Int) = us.vectorizedloopnum == n
58+
function unrollfactor(us::UnrollSpecification, n::Int)
59+
@unpack unrolledloopnum, tiledloopnum, U, T = us
60+
(unrolledloopnum == n) ? U : ((tiledloopnum == n) ? T : 1)
61+
end
62+
3963
struct Loop
4064
itersymbol::Symbol
4165
starthint::Int
@@ -57,7 +81,7 @@ function Loop(itersymbol::Symbol, start::Union{Int,Symbol}, stop::Union{Int,Symb
5781
end
5882
Base.length(loop::Loop) = 1 + loop.stophint - loop.starthint
5983
isstaticloop(loop::Loop) = loop.startexact & loop.stopexact
60-
function startloop(loop::Loop, isvectorized, W, itersymbol = loop.itersymbol)
84+
function startloop(loop::Loop, isvectorized, W, itersymbol)
6185
startexact = loop.startexact
6286
if isvectorized
6387
if startexact
@@ -71,41 +95,48 @@ function startloop(loop::Loop, isvectorized, W, itersymbol = loop.itersymbol)
7195
Expr(:(=), itersymbol, Expr(:call, lv(:unwrap), loop.startsym))
7296
end
7397
end
74-
function vec_looprange(loop::Loop, isunrolled::Bool, W::Symbol, U::Int)
98+
function vec_looprange(loop::Loop, W::Symbol, UF::Int, mangledname::Symbol)
99+
isunrolled = UF > 1
75100
incr = if isunrolled
76-
Expr(:call, lv(:valmuladd), W, U, -2)
101+
Expr(:call, lv(:valmuladd), W, UF, -2)
77102
else
78103
Expr(:call, lv(:valsub), W, 2)
79104
end
80105
if loop.stopexact # split for type stability
81-
Expr(:call, :<, loop.itersymbol, Expr(:call, :-, loop.stophint, incr))
106+
Expr(:call, :<, mangledname, Expr(:call, :-, loop.stophint, incr))
82107
else
83-
Expr(:call, :<, loop.itersymbol, Expr(:call, :-, loop.stopsym, incr))
108+
Expr(:call, :<, mangledname, Expr(:call, :-, loop.stopsym, incr))
84109
end
85110
end
86111
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
87-
incr -= 2
112+
incr = 2 - incr
88113
if iszero(incr)
89114
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint : loop.stopsym)
90115
else
91-
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint - incr : Expr(:call, :-, loop.stopsym, incr))
116+
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint + incr : Expr(:call, :+, loop.stopsym, incr))
92117
end
93118
end
94119
function terminatecondition(
95-
loop::Loop, W::Symbol, U::Int, T::Int, isvectorized::Bool, isunrolled::Bool, istiled::Bool,
96-
mangledname::Symbol = loop.itersymbol, mask::Nothing = nothing
120+
loop::Loop, us::UnrollSpecification, n::Int, W::Symbol, mangledname::Symbol, inclmask::Bool, UF::Int = unrollfactor(us, n)
97121
)
98-
if isvectorized
99-
vec_looprange(loop, isunrolled, W, U) # may not be tiled
122+
if !isvectorized(us, n)
123+
looprange(loop, UF, mangledname)
124+
elseif inclmask
125+
looprange(loop, 1, mangledname)
100126
else
101-
looprange(loop, isunrolled ? U : (istiled ? T : 1), mangledname)
127+
vec_looprange(loop, W, UF, mangledname) # may not be tiled
102128
end
103129
end
104-
function terminatecondition(
105-
loop::Loop, W::Symbol, U::Int, T::Int, isvectorized::Bool, isunrolled::Bool, istiled::Bool,
106-
mangledname::Symbol, mask::Symbol
107-
)
108-
looprange(loop, 1, mangledname)
130+
function incrementloopcounter(us::UnrollSpecification, n::Int, W::Symbol, mangledname::Symbol, UF::Int = unrollfactor(us, n))
131+
if isvectorized(us, n)
132+
if UF == 1
133+
Expr(:(=), mangledname, Expr(:call, lv(:valadd), W, mangledname))
134+
else
135+
Expr(:+=, mangledname, Expr(:call, lv(:valmul), W, UF))
136+
end
137+
else
138+
Expr(:+=, mangledname, UF)
139+
end
109140
end
110141

111142
# load/compute/store × isunrolled × istiled × pre/post loop × Loop number
@@ -556,3 +587,11 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
556587
throw("Don't know how to handle expression:\n$ex")
557588
end
558589
end
590+
591+
function UnrollSpecification(ls::LoopSet, unrolled::Symbol, tiled::Symbol, vectorized::Symbol, U, T)
592+
order = names(ls)
593+
nu = findfirst(isequal(unrolled), order)::Int
594+
nt = T == -1 ? nu : findfirst(isequal(tiled), order)::Int
595+
nv = findfirst(isequal(vectorized), order)::Int
596+
UnrollSpecification(nu, nt, nv, U, T)
597+
end

src/lower_compute.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function lower_compute!(
5858
Uiter = opunrolled ? U - 1 : 0
5959
isreduct = isreduction(op)
6060
if !isnothing(suffix) && isreduct && tiledouterreduction == -1
61-
instrfid = findfirst(isequal(instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
61+
instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
6262
if instrfid !== nothing
6363
instr = Instruction((:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)[instrfid])
6464
end

0 commit comments

Comments
 (0)