Skip to content

Commit 4d77be0

Browse files
committed
A few changes to determinestrategy, and update benchmark results.
1 parent 6389458 commit 4d77be0

28 files changed

+175
-86
lines changed

Manifest.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4242

4343
[[SIMDPirates]]
4444
deps = ["VectorizationBase"]
45-
git-tree-sha1 = "9fc6737cd40087e7d486f8e81fb5be8ad18f970b"
45+
git-tree-sha1 = "1a9cbe1be1f5d43ac49eeb38ca64dd78e50a0cc6"
4646
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
47-
version = "0.7.2"
47+
version = "0.7.3"
4848

4949
[[SLEEFPirates]]
5050
deps = ["Libdl", "SIMDPirates", "VectorizationBase"]
@@ -69,6 +69,6 @@ version = "0.1.0"
6969

7070
[[VectorizationBase]]
7171
deps = ["CpuId", "LinearAlgebra"]
72-
git-tree-sha1 = "76e8817f7732d9a127191f5bcd5fe3a5eed0fb3e"
72+
git-tree-sha1 = "83f073a514b5d654cc9c72ae283a33388a0d0386"
7373
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
74-
version = "0.9.2"
74+
version = "0.9.3"

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1313

1414
[compat]
1515
OffsetArrays = "1"
16-
SIMDPirates = "0.7.1"
16+
SIMDPirates = "0.7.3"
1717
SLEEFPirates = "0.4"
1818
UnPack = "0"
19-
VectorizationBase = "0.9.2"
19+
VectorizationBase = "0.9.3"
2020
julia = "1.1"
2121

2222
[extras]

docs/src/assets/bench_AmulB_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_AmulBt_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_Amulvb_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_AplusAt_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_AtmulB_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_AtmulBt_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_Atmulvb_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_aplusBc_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_dot3_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_dot_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_exp_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_filter2d_3x3_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_filter2d_dynamic_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_filter2d_unrolled_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_logdettriangle_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_random_access_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_selfdot_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/src/assets/bench_sse_v1.svg

Lines changed: 1 addition & 1 deletion
Loading

src/determinestrategy.jl

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -352,16 +352,25 @@ function solve_tilesize(
352352
reg_pressure::AbstractVector{Float64},
353353
W::Int, vectorized::Symbol
354354
)
355-
maxTbase = maxUbase = 4#8
355+
maxTbase = maxUbase = 6#8
356356
maxT = maxTbase#8
357357
maxU = maxUbase#8
358358
tiledloop = getloop(ls, tiled)
359359
unrolledloop = getloop(ls, unrolled)
360360
if isstaticloop(tiledloop)
361+
if length(tiledloop) ≤ 4
362+
T = length(tiledloop)
363+
U = max(1, solve_tilesize_constT(cost_vec, reg_pressure, T))
364+
return U, T, tile_cost(cost_vec, U, T, length(unrolledloop), T)
365+
end
361366
maxT = min(4maxT, length(tiledloop))
362367
end
363368
if isstaticloop(unrolledloop)
364369
UL = length(unrolledloop)
370+
if unrolled !== vectorized && UL ≤ 4
371+
T = max(1, solve_tilesize_constU(cost_vec, reg_pressure, UL))
372+
return UL, T, tile_cost(cost_vec, UL, T, UL, length(tiledloop))
373+
end
365374
UL = unrolled === vectorized ? cld(UL,W) : UL
366375
maxU = min(4maxU, UL)
367376
end
@@ -383,11 +392,22 @@ function set_upstream_family!(adal::Vector{T}, op::Operation, val::T) where {T}
383392
set_upstream_family!(adal, opp, val)
384393
end
385394
end
386-
395+
function stride_penalty_opdependent(ls::LoopSet, op::Operation, order::Vector{Symbol}, contigsym::Symbol)
396+
num_loops = length(order)
397+
firstloopdeps = loopdependencies(findparent(ls, contigsym))
398+
iter = 1
399+
for i ∈ 0:num_loops - 1
400+
loopsym = order[num_loops - i]
401+
loopsym ∈ firstloopdeps && return iter
402+
iter *= length(getloop(ls, loopsym))
403+
end
404+
iter
405+
end
387406
function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol})
388407
num_loops = length(order)
389-
contigsym = first(loopdependencies(op))
408+
contigsym = first(loopdependencies(op.ref))
390409
contigsym == Symbol("##DISCONTIGUOUSSUBARRAY##") && return 0
410+
first(op.ref.loopedindex) || return stride_penalty_opdependent(ls, op, order, contigsym)
391411
iter = 1
392412
for i ∈ 0:num_loops - 1
393413
loopsym = order[num_loops - i]
@@ -405,19 +425,40 @@ function stride_penalty(ls::LoopSet, order::Vector{Symbol})
405425
end
406426
stridepenalty# * 1e-9
407427
end
408-
function convolution_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol)
409-
(u1 ∈ loopdependencies(op) && u2 ∈ loopdependencies(op)) || return 1.0, 1.0
428+
function isoptranslation(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, vectorized::Symbol)
429+
(vectorized == u1 || vectorized == u2) && return false, false
430+
(u1 ∈ loopdependencies(op) && u2 ∈ loopdependencies(op)) || return false, false
410431
istranslation = false
411432
inds = getindices(op); li = op.ref.loopedindex
433+
translationplus = false
412434
for i ∈ eachindex(li)
413435
if !li[i]
414436
opp = findparent(ls, inds[i + (first(inds) === Symbol("##DISCONTIGUOUSSUBARRAY##"))])
415437
if instruction(opp).instr ∈ (:+, :-) && u1 ∈ loopdependencies(opp) && u2 ∈ loopdependencies(opp)
416438
istranslation = true
439+
translationplus = instruction(opp).instr === :+
440+
end
441+
end
442+
end
443+
istranslation, translationplus
444+
end
445+
function convolution_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
446+
if first(isoptranslation(ls, op, u1, u2, v))
447+
for loop ∈ ls.loops
448+
# If another loop is short, assume that LLVM will unroll it, in which case
449+
# we want to be a little more conservative in terms of register pressure.
450+
#FIXME: heuristic hack to get some desired behavior.
451+
if isstaticloop(loop) && length(loop) ≤ 4
452+
itersym = loop.itersymbol
453+
if itersym !== u1 && itersym !== u2
454+
return (0.25, 1.0)
455+
end
417456
end
418457
end
458+
(0.25, 0.5)
459+
else
460+
(1.0, 1.0)
419461
end
420-
istranslation ? (0.25, 1.0) : (1.0, 1.0)
421462
end
422463
# Just tile outer two loops?
423464
# But optimal order within tile must still be determined
@@ -484,7 +525,7 @@ function evaluate_cost_tile(
484525
istiled = unrolledtiled[2,id]
485526
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
486527
if isload(op)
487-
factor1, factor2 = convolution_cost_factor(ls, op, unrolled, tiled)
528+
factor1, factor2 = convolution_cost_factor(ls, op, unrolled, tiled, vectorized)
488529
rt *= factor1; rp *= factor2;
489530
end
490531
# @show op rt, lat, rp

src/lower_load.jl

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
function pushvectorload!(q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, mask, vecnotunrolled::Bool)
2-
@unpack u, unrolled = td
3-
ptr = refname(op)
4-
name, mo = name_memoffset(var, op, td, W, vecnotunrolled)
5-
instrcall = Expr(:call, lv(:vload), ptr, mo)
6-
if mask !== nothing && (vecnotunrolled || u == U - 1)
7-
push!(instrcall.args, mask)
8-
end
9-
push!(q.args, Expr(:(=), name, instrcall))
10-
end
111
function lower_load_scalar!(
122
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
13-
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
3+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing, umin::Int = 0
144
)
155
loopdeps = loopdependencies(op)
166
@assert vectorized ∉ loopdeps
177
var = variable_name(op, suffix)
188
ptr = refname(op)
199
isunrolled = unrolled ∈ loopdeps
2010
U = isunrolled ? U : 1
21-
for u ∈ 0:U-1
11+
for u ∈ umin:U-1
2212
varname = varassignname(var, u, isunrolled)
2313
td = UnrollArgs(u, unrolled, tiled, suffix)
2414
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))))
2515
end
2616
nothing
2717
end
18+
function pushvectorload!(q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, mask, vecnotunrolled::Bool)
19+
@unpack u, unrolled = td
20+
ptr = refname(op)
21+
name, mo = name_memoffset(var, op, td, W, vecnotunrolled)
22+
instrcall = Expr(:call, lv(:vload), ptr, mo)
23+
if mask !== nothing && (vecnotunrolled || u == U - 1)
24+
push!(instrcall.args, mask)
25+
end
26+
push!(q.args, Expr(:(=), name, instrcall))
27+
end
2828
function lower_load_vectorized!(
2929
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
30-
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
30+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing, umin::Int = 0
3131
)
3232
loopdeps = loopdependencies(op)
3333
@assert vectorized ∈ loopdeps
3434
if unrolled ∈ loopdeps
35-
umin = 0
35+
umin = umin
3636
U = U
3737
else
3838
umin = -1
@@ -51,12 +51,28 @@ end
5151
# TODO: this code should be rewritten to be more "orthogonal", so that we're just combining separate pieces.
5252
# Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
5353
function lower_load!(
54-
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
54+
q::Expr, op::Operation, vectorized::Symbol, ls::LoopSet, unrolled::Symbol, tiled::Symbol, U::Int,
5555
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
5656
)
57+
if !isnothing(suffix) && suffix > 0
58+
istr, ispl = isoptranslation(ls, op, unrolled, tiled, vectorized)
59+
if istr && ispl
60+
varnew = variable_name(op, suffix)
61+
varold = variable_name(op, suffix - 1)
62+
for u ∈ 0:U-2
63+
push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u + 1)))
64+
end
65+
umin = U - 1
66+
else
67+
umin = 0
68+
end
69+
else
70+
umin = 0
71+
end
72+
W = ls.W
5773
if vectorized ∈ loopdependencies(op)
58-
lower_load_vectorized!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
74+
lower_load_vectorized!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, umin)
5975
else
60-
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
76+
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, umin)
6177
end
6278
end

src/lowering.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function lower!(
2121
lower_zero!(q, op, vectorized, ls, unrolled, U, suffix, zerotyp)
2222
end
2323
elseif isload(op)
24-
lower_load!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
24+
lower_load!(q, op, vectorized, ls, unrolled, tiled, U, suffix, mask)
2525
elseif iscompute(op)
2626
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
2727
elseif isstore(op)
@@ -47,7 +47,7 @@ function lower!(
4747
lower_zero!(q, op, vectorized, ls, unrolled, U, suffix, zerotyp)
4848
end
4949
elseif isload(op)
50-
lower_load!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
50+
lower_load!(q, op, vectorized, ls, unrolled, tiled, U, suffix, mask)
5151
elseif iscompute(op)
5252
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
5353
end

src/reconstruct_loopset.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
function Loop(ls::LoopSet, l::Int, sym::Symbol, ::Type{UnitRange{Int}})
2-
start = gensym(String(sym)*"_loopstart"); stop = gensym(String(sym)*"_loopstop")
3-
pushpreamble!(ls, Expr(:(=), start, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:start)))))
4-
pushpreamble!(ls, Expr(:(=), stop, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:stop)))))
1+
function Loop(ls::LoopSet, ex::Expr, sym::Symbol, ::Type{<:AbstractUnitRange})
2+
ssym = String(sym)
3+
start = gensym(ssym*"_loopstart"); stop = gensym(ssym*"_loopstop"); loopsym = gensym(ssym * "_loop")
4+
pushpreamble!(ls, Expr(:(=), loopsym, ex))
5+
pushpreamble!(ls, Expr(:(=), start, Expr(:call, :first, loopsym)))
6+
pushpreamble!(ls, Expr(:(=), stop, Expr(:call, :last, loopsym)))
57
Loop(sym, 0, 1024, start, stop, false, false)::Loop
68
end
7-
function Loop(ls::LoopSet, l::Int, sym::Symbol, ::Type{StaticUpperUnitRange{U}}) where {U}
9+
function Loop(ls::LoopSet, ex::Expr, sym::Symbol, ::Type{StaticUpperUnitRange{U}}) where {U}
810
start = gensym(String(sym)*"_loopstart")
9-
pushpreamble!(ls, Expr(:(=), start, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:L)))))
11+
pushpreamble!(ls, Expr(:(=), start, Expr(:(.), ex, QuoteNode(:L))))
1012
Loop(sym, U - 1024, U, start, Symbol(""), false, true)::Loop
1113
end
12-
function Loop(ls::LoopSet, l::Int, sym::Symbol, ::Type{StaticLowerUnitRange{L}}) where {L}
14+
function Loop(ls::LoopSet, ex::Expr, sym::String, ::Type{StaticLowerUnitRange{L}}) where {L}
1315
stop = gensym(String(sym)*"_loopstop")
14-
pushpreamble!(ls, Expr(:(=), stop, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:U)))))
16+
pushpreamble!(ls, Expr(:(=), stop, Expr(:(.), ex, QuoteNode(:U))))
1517
Loop(sym, L, L + 1024, Symbol(""), stop, true, false)::Loop
1618
end
1719
# Is there any likely way to generate such a range?
@@ -21,17 +23,12 @@ end
2123
# pushpreamble!(ls, Expr(:(=), stop, Expr(:call, :(+), start, N - 1)))
2224
# Loop(gensym(:n), 0, N, start, stop, false, false)::Loop
2325
# end
24-
function Loop(ls, l, sym::Symbol, ::Type{StaticUnitRange{L,U}}) where {L,U}
26+
function Loop(::LoopSet, ::Expr, sym::Symbol, ::Type{StaticUnitRange{L,U}}) where {L,U}
2527
Loop(sym, L, U, Symbol(""), Symbol(""), true, true)::Loop
2628
end
2729

28-
function Loop(ls::LoopSet, l::Int, k::Int, sym::Symbol, ::Type{<:CartesianIndices{N}}) where N
29-
str = String(sym)*'#'*string(k)*'#'
30-
start = gensym(str*"_loopstart"); stop = gensym(str*"_loopstop")
31-
axisexpr = Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:ref, Expr(:., Expr(:ref, :lb, l), QuoteNode(:indices)), k))
32-
pushpreamble!(ls, Expr(:(=), start, Expr(:call, :first, axisexpr)))
33-
pushpreamble!(ls, Expr(:(=), stop, Expr(:call, :last, axisexpr)))
34-
Loop(Symbol(str), 0, 1024, start, stop, false, false)::Loop
30+
function extract_loop(l)
31+
Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:ref, :lb, l))
3532
end
3633

3734
function add_loops!(ls::LoopSet, LPSYM, LB)
@@ -41,14 +38,16 @@ function add_loops!(ls::LoopSet, LPSYM, LB)
4138
if l<:CartesianIndices
4239
add_loops!(ls, i, sym, l)
4340
else
44-
add_loop!(ls, Loop(ls, i, sym, l)::Loop)
41+
add_loop!(ls, Loop(ls, extract_loop(i), sym, l)::Loop)
4542
push!(ls.loopsymbol_offsets, ls.loopsymbol_offsets[end]+1)
4643
end
4744
end
4845
end
49-
function add_loops!(ls, i, sym, l::Type{<:CartesianIndices{N}}) where N
46+
function add_loops!(ls::LoopSet, i::Int, sym::Symbol, l::Type{CartesianIndices{N,T}}) where {N,T}
47+
ssym = String(sym)
5048
for k = N:-1:1
51-
add_loop!(ls, Loop(ls, i, k, sym, l)::Loop)
49+
axisexpr = Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:ref, Expr(:., Expr(:ref, :lb, i), QuoteNode(:indices)), k))
50+
add_loop!(ls, Loop(ls, axisexpr, Symbol(ssym*'#'*string(k)*'#'), T.parameters[k])::Loop)
5251
end
5352
push!(ls.loopsymbol_offsets, ls.loopsymbol_offsets[end]+N)
5453
end

test/gemm.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testset "GEMM" begin
22
# using LoopVectorization, LinearAlgebra, Test; T = Float64
3-
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
3+
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (5, 5)
44
AmulBtq1 = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
55
C[m,n] = zeroB
66
for k ∈ 1:size(A,2)
@@ -291,7 +291,11 @@
291291
C[m,n] = Cₘₙ
292292
end)
293293
lsr2amb = LoopVectorization.LoopSet(r2ambq);
294-
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :k, :n, :m, Unum & -2, Tnum)
294+
if LoopVectorization.VectorizationBase.REGISTER_COUNT == 32
295+
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :k, :n, :m, 3, 6)
296+
else
297+
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :k, :n, :m, 2, 4)
298+
end
295299
function rank2AmulBavx!(C, Aₘ, Aₖ, B)
296300
@avx for m ∈ 1:size(C,1), n ∈ 1:size(C,2)
297301
Cₘₙ = zero(eltype(C))

test/gemv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using LoopVectorization
22
using Test
33

44
@testset "GEMV" begin
5-
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
5+
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 6)
66
gemvq = :(for i ∈ eachindex(y)
77
yᵢ = 0.0
88
for j ∈ eachindex(x)

test/miscellaneous.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
@testset "Miscellaneous" begin
3-
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
3+
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 6)
44
dot3q = :(for m ∈ 1:M, n ∈ 1:N
55
s += x[m] * A[m,n] * y[n]
66
end);

0 commit comments

Comments
 (0)