Skip to content

Commit 39b069b

Browse files
committed
Big fixes; tests pass locally on latest VectorizationBase and SIMDPirates.
1 parent 3b5c7ce commit 39b069b

File tree

7 files changed

+162
-41
lines changed

7 files changed

+162
-41
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export LowDimArray, stridedpointer, vectorizable,
1717
vmap, vmap!
1818

1919

20+
include("map.jl")
2021
include("costs.jl")
2122
include("operations.jl")
2223
include("graphs.jl")
@@ -29,11 +30,9 @@ include("add_ifelse.jl")
2930
include("broadcast.jl")
3031
include("determinestrategy.jl")
3132
include("lowering.jl")
32-
include("constructors.jl")
33-
include("map.jl")
34-
# include("_avx.jl")
3533
include("condense_loopset.jl")
3634
include("reconstruct_loopset.jl")
35+
include("constructors.jl")
3736

3837
export @_avx, _avx, @_avx_, avx_!
3938

src/broadcast.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ end
3838

3939
@inline *ˡ(a::A, b::B) where {A,B} = Product{A,B}(a, b)
4040
@inline Base.Broadcast.broadcasted(::typeof(*ˡ), a::A, b::B) where {A, B} = Product{A,B}(a, b)
41-
const = *ˡ
4241
# TODO: Need to make this handle A or B being (1 or 2)-D broadcast objects.
4342
function add_broadcast!(
4443
ls::LoopSet, mC::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},

src/condense_loopset.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ end
161161

162162

163163
# Try to condense in type stable manner
164-
function generate_call(ls::LoopSet)
164+
function generate_call(ls::LoopSet, IUT)
165165
operation_descriptions = Expr(:curly, :Tuple)
166166
varnames = Symbol[]
167167
for op ∈ operations(ls)
@@ -176,16 +176,16 @@ function generate_call(ls::LoopSet)
176176
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
177177
loop_bounds = loop_boundaries(ls)
178178

179-
q = Expr(:call, lv(:_avx_!), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
179+
q = Expr(:call, lv(:_avx_!), Expr(:call, Expr(:curly, :Val, IUT)), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
180180

181181
foreach(ref -> push!(q.args, vptr(ref)), ls.refs_aliasing_syms)
182182
foreach(is -> push!(q.args, last(is)), ls.preamble_symsym)
183183
append!(q.args, arraysymbolinds)
184184
q
185185
end
186186

187-
function setup_call(ls::LoopSet)
188-
call = generate_call(ls)
187+
function setup_call(ls::LoopSet, inline = one(Int8), U = zero(Int8), T = zero(Int8))
188+
call = generate_call(ls, (inline,U,T))
189189
hasouterreductions = length(ls.outer_reductions) > 0
190190
if hasouterreductions
191191
retv = loopset_return_value(ls, Val(false))
@@ -208,7 +208,4 @@ function setup_call(ls::LoopSet)
208208
ls.preamble
209209
end
210210

211-
macro _avx(q)
212-
esc(setup_call(LoopSet(q)))
213-
end
214211

src/constructors.jl

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,69 @@ true
9292
"""
9393
macro avx(q)
9494
q2 = if q.head === :for
95-
lower(LoopSet(q))
95+
setup_call(LoopSet(q))
9696
else# assume broadcast
9797
substitute_broadcast(q)
9898
end
9999
esc(q2)
100100
end
101+
102+
# Parse an `inline=<Bool>` macro option.
# `arg` is expected to be an expression whose first argument is the option name
# and whose second argument is its value. Returns the `Bool` value when the
# option name is `:inline`, and `nothing` for any other name. The type
# assertions raise a `TypeError` for a non-Symbol name or a non-Bool value.
function check_inline(arg)
    key = arg.args[1]::Symbol
    if key === :inline
        return arg.args[2]::Bool
    end
    return nothing
end
107+
# Parse a `tile=(U,T)` macro option.
# Returns `nothing` when the option name is not `:tile`; otherwise returns the
# two tile factors converted to `Int8` as the tuple `(U, T)`.
# Throws an `ArgumentError` when the right-hand side is not a 2-element tuple
# expression, so a malformed option is reported clearly at macro-expansion time.
function check_tile(arg)
    key = arg.args[1]::Symbol
    key === :tile || return nothing
    # The right-hand side of `tile=(U,T)` is the tuple expression to unpack.
    tup = arg.args[2]
    Meta.isexpr(tup, :tuple, 2) ||
        throw(ArgumentError("tile must be given as a 2-tuple, e.g. tile=(4,4); got $(tup)"))
    U = convert(Int8, tup.args[1])
    T = convert(Int8, tup.args[2])
    return U, T
end
114+
# Parse an `unroll=<integer>` macro option.
# Returns the unroll factor converted to `Int8` when the option name is
# `:unroll`, and `nothing` for any other name. A non-Symbol name raises a
# `TypeError` through the type assertion.
function check_unroll(arg)
    key = arg.args[1]::Symbol
    if key === :unroll
        return convert(Int8, arg.args[2])
    end
    return nothing
end
119+
# Fold one `name=value` macro option into the `(inline, U, T)` settings triple.
#
# Arguments:
#   arg    - an assignment expression such as `inline=true`, `unroll=2` or `tile=(4,4)`
#   inline - current inline setting (defaults to 1)
#   U, T   - current unroll / tile settings (default to 0)
#
# Returns the updated `(inline, U, T)`:
#   inline=true  -> inline becomes  2
#   inline=false -> inline becomes -1
#   unroll=u     -> U = u, T = -1
#   tile=(u,t)   -> U = u, T = t
# Throws an `ArgumentError` for any other option name, instead of failing later
# while destructuring `nothing`.
function check_macro_kwarg(arg, inline::Int8 = one(Int8), U::Int8 = zero(Int8), T::Int8 = zero(Int8))
    @assert arg.head === :(=)
    i = check_inline(arg)
    if i !== nothing
        inline = i ? Int8(2) : Int8(-1)
        return inline, U, T
    end
    u = check_unroll(arg)
    if u !== nothing
        # A plain unroll request marks the loop as not tiled.
        return inline, u, Int8(-1)
    end
    ut = check_tile(arg)
    ut === nothing && throw(ArgumentError(
        "unrecognized option `$(arg.args[1])`; expected `inline`, `unroll`, or `tile`"
    ))
    U, T = ut
    return inline, U, T
end
101135
macro avx(arg, q)
102136
@assert q.head === :for
103137
@assert arg.head === :(=)
104-
local U::Int, T::Int
105-
if arg.args[1] === :unroll
106-
U = arg.args[2]
107-
T = -1
108-
elseif arg.args[1] === :tile
109-
tup = arg.args[2]
110-
@assert tup.head === :tuple
111-
U = tup.args[1]
112-
T = tup.args[2]
113-
end
138+
inline, U, T = check_macro_kwarg(arg)
139+
esc(setup_call(LoopSet(q), inline, U, T))
140+
end
141+
# Two-option form of the macro: `arg1` and `arg2` are `name=value` options and
# `q` must be a `for` loop. The options are folded left to right into the
# `(inline, U, T)` settings, which are then passed to `setup_call` together
# with the `LoopSet` built from `q`.
macro avx(arg1, arg2, q)
    @assert q.head === :for
    first_pass = check_macro_kwarg(arg1)
    inline, U, T = check_macro_kwarg(arg2, first_pass...)
    call = setup_call(LoopSet(q), inline, U, T)
    esc(call)
end
147+
148+
149+
150+
# Build a `LoopSet` from the expression `q`, lower it with the default
# settings, and return the resulting expression escaped into the caller's scope.
macro _avx(q)
    ls = LoopSet(q)
    lowered = lower(ls)
    return esc(lowered)
end
153+
# One-option form of `@_avx`: `arg` is a `name=value` option and `q` must be a
# `for` loop. The loop is lowered directly into the caller's scope.
#
# `check_macro_kwarg` leaves `U == 0` when neither `unroll` nor `tile` was
# given (e.g. `inline=true`). In that case fall back to `lower(ls)`, which
# picks U and T via `choose_order`, instead of forwarding the zero
# placeholders; `lower(ls, U, T)` treats any `T != -1` as a tiling request.
# The inline setting is parsed but not used by this macro.
macro _avx(arg, q)
    @assert q.head === :for
    _, U, T = check_macro_kwarg(arg)
    ls = LoopSet(q)
    lowered = iszero(U) ? lower(ls) : lower(ls, U, T)
    esc(lowered)
end
116-
158+
159+
117160

src/lowering.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,22 +1001,41 @@ end
10011001

10021002

10031003

1004+
# Optionally prepend an inlining hint to the expression `q`, then return `q`.
# `prependinlineORorUnroll` selects the behaviour:
#   -1 : prepend `Expr(:meta, :noinline)`
#    1 : prepend `Expr(:meta, :inline)` when the condition below holds
#    2 : prepend `Expr(:meta, :inline)`
#   any other value : leave `q` untouched
function maybeinline!(q, ls, istiled, prependinlineORorUnroll)
    mode = prependinlineORorUnroll
    hint = if mode == 1
        # `|` binds tighter than `>`, so this is `(!istiled | n) > 1`.
        # NOTE(review): for a Bool `istiled` this is true exactly when
        # `length(ls.outer_reductions) > 1`; confirm that
        # `!istiled || length(...) > 1` was not the intent.
        ((!istiled) | length(ls.outer_reductions)) > 1 ? :inline : nothing
    elseif mode == 2
        :inline
    elseif mode == -1
        :noinline
    else
        nothing
    end
    hint === nothing || pushfirst!(q.args, Expr(:meta, hint))
    return q
end
10041016
# Here, we have to figure out how to convert the loopset into a vectorized expression.
10051017
# This must traverse in a parent -> child pattern
10061018
# but order is also dependent on which loop inds they depend on.
1007-
# Requires sorting
1008-
function lower(ls::LoopSet)
1019+
# Requires sorting
1020+
# values for prependinlineORorUnroll:
1021+
# -1 : force @noinline
1022+
# 0 : nothing
1023+
# 1 : inline if length(ls.outer_reductions) > 1
1024+
# 2 : force inline
1025+
function lower(ls::LoopSet, prependinlineORorUnroll = 0)
10091026
order, vectorized, U, T = choose_order(ls)
10101027
istiled = T != -1
10111028
fillorder!(ls, order, istiled)
1012-
istiled ? lower_tiled(ls, vectorized, U, T) : lower_unrolled(ls, vectorized, U)
1029+
q = istiled ? lower_tiled(ls, vectorized, U, T) : lower_unrolled(ls, vectorized, U)
1030+
maybeinline!(q, ls, istiled, prependinlineORorUnroll)
10131031
end
1014-
function lower(ls::LoopSet, U, T = -1)
1032+
function lower(ls::LoopSet, U, T, prependinlineORorUnroll = 0)
10151033
num_loops(ls) == 1 && @assert T == -1
10161034
order, vectorized, _U, _T = choose_order(ls)
10171035
istiled = T != -1
10181036
fillorder!(ls, order, istiled)
1019-
istiled ? lower_tiled(ls, vectorized, U, T) : lower_unrolled(ls, vectorized, U)
1037+
q = istiled ? lower_tiled(ls, vectorized, Int(U), Int(T)) : lower_unrolled(ls, vectorized, Int(U))
1038+
maybeinline!(q, ls, istiled, prependinlineORorUnroll)
10201039
end
10211040

10221041
Base.convert(::Type{Expr}, ls::LoopSet) = lower(ls)

src/reconstruct_loopset.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function sizeofeltypes(v, num_arrays)::Int
190190
sizeof(T)
191191
end
192192

193-
function avx_body(instr, ops, arf, AM, LB, vargs)
193+
function avx_body(IUT, instr, ops, arf, AM, LB, vargs)
194194
ls = LoopSet()
195195
# elementbytes = mapreduce(elbytes, min, @view(vargs[Base.OneTo(length(arf))]))::Int
196196
num_arrays = length(arf)
@@ -203,20 +203,20 @@ function avx_body(instr, ops, arf, AM, LB, vargs)
203203
add_ops!(ls, instr, ops, mrefs, opsymbols, elementbytes)
204204
add_array_symbols!(ls, arraysymbolinds, num_arrays + length(ls.preamble_symsym))
205205
pushpreamble!(ls, Expr(:(=), ls.T, Expr(:call, :promote_type, [Expr(:call, :eltype, vptr(mref)) for mref ∈ mrefs]...)))
206-
q = lower(ls)
207-
push!(q.args, loopset_return_value(ls, Val(true)))
206+
inline, U, T = IUT
207+
q = iszero(U) ? lower(ls, inline) : lower(ls, U, T, inline)
208+
length(ls.outer_reductions) > 0 ? push!(q.args, loopset_return_value(ls, Val(true))) : push!(q.args, nothing)
208209
# @show q
209-
length(ls.outer_reductions) > 1 && pushfirst!(q.args, Expr(:meta, :inline))
210210
q
211211
end
212212

213-
@generated function _avx_!(::Type{OPS}, ::Type{ARF}, ::Type{AM}, lb::LB, vargs...) where {OPS, ARF, AM, LB}
213+
@generated function _avx_!(::Val{IUT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, lb::LB, vargs...) where {IUT, OPS, ARF, AM, LB}
214214
OPSsv = OPS.parameters
215215
nops = length(OPSsv) ÷ 3
216216
instr = Instruction[Instruction(OPSsv[3i+1], OPSsv[3i+2]) for i ∈ 0:nops-1]
217217
ops = OperationStruct[ OPSsv[3i] for i ∈ 1:nops ]
218218
avx_body(
219-
instr, ops,
219+
IUT, instr, ops,
220220
ArrayRefStruct[ARF.parameters...],
221221
AM.parameters, LB.parameters, vargs
222222
)

test/runtests.jl

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -355,47 +355,76 @@ using LinearAlgebra
355355
AmulB!(C2, A, B)
356356
AmulBavx1!(C, A, B)
357357
@test C ≈ C2
358+
fill!(C, 999.99); AmulBavx1!(C, At', B)
359+
@test C ≈ C2
358360
fill!(C, 999.99); AmulBavx2!(C, A, B)
359361
@test C ≈ C2
362+
fill!(C, 999.99); AmulBavx2!(C, At', B)
363+
@test C ≈ C2
360364
fill!(C, 999.99); AmulBavx3!(C, A, B)
361365
@test C ≈ C2
366+
fill!(C, 999.99); AmulBavx3!(C, At', B)
367+
@test C ≈ C2
362368
fill!(C, 0.0); AmuladdBavx!(C, A, B)
363369
@test C ≈ C2
364-
AmuladdBavx!(C, A, B)
370+
AmuladdBavx!(C, At', B)
365371
@test C ≈ 2C2
366372
AmuladdBavx!(C, A, B, -1)
367373
@test C ≈ C2
374+
AmuladdBavx!(C, At', B, -2)
375+
@test C ≈ -C2
368376
fill!(C, 9999.999); AtmulBavx!(C, At, B)
369377
@test C ≈ C2
378+
fill!(C, 9999.999); AtmulBavx!(C, A', B)
379+
@test C ≈ C2
370380
fill!(C, 9999.999); mulCAtB_2x2blockavx!(C, At, B);
371381
@test C ≈ C2
382+
fill!(C, 9999.999); mulCAtB_2x2blockavx!(C, A', B);
383+
@test C ≈ C2
372384
end
373385
@time @testset "_avx $T gemm" begin
374-
fill!(C, 999.99); AmulB_avx1!(C, A, B)
386+
AmulB_avx1!(C, A, B)
387+
@test C ≈ C2
388+
fill!(C, 999.99); AmulB_avx1!(C, At', B)
375389
@test C ≈ C2
376390
fill!(C, 999.99); AmulB_avx2!(C, A, B)
377391
@test C ≈ C2
392+
fill!(C, 999.99); AmulB_avx2!(C, At', B)
393+
@test C ≈ C2
378394
fill!(C, 999.99); AmulB_avx3!(C, A, B)
379395
@test C ≈ C2
396+
fill!(C, 999.99); AmulB_avx3!(C, At', B)
397+
@test C ≈ C2
380398
fill!(C, 0.0); AmuladdB_avx!(C, A, B)
381399
@test C ≈ C2
382-
AmuladdB_avx!(C, A, B)
400+
AmuladdB_avx!(C, At', B)
383401
@test C ≈ 2C2
384402
AmuladdB_avx!(C, A, B, -1)
385403
@test C ≈ C2
404+
AmuladdB_avx!(C, At', B, -2)
405+
@test C ≈ -C2
386406
fill!(C, 9999.999); AtmulB_avx!(C, At, B)
387407
@test C ≈ C2
408+
fill!(C, 9999.999); AtmulB_avx!(C, A', B)
409+
@test C ≈ C2
388410
fill!(C, 9999.999); mulCAtB_2x2block_avx!(C, At, B);
389411
@test C ≈ C2
412+
fill!(C, 9999.999); mulCAtB_2x2block_avx!(C, A', B);
413+
@test C ≈ C2
390414
end
391415

392416
@time @testset "$T rank2mul" begin
393417
Aₘ= rand(R, M, 2); Aₖ = rand(R, 2, K);
418+
Aₖ′ = copy(Aₖ')
394419
rank2AmulB!(C2, Aₘ, Aₖ, B)
395420
rank2AmulBavx!(C, Aₘ, Aₖ, B)
396421
@test C ≈ C2
397422
fill!(C, 9999.999); rank2AmulB_avx!(C, Aₘ, Aₖ, B)
398423
@test C ≈ C2
424+
fill!(C, 9999.999); rank2AmulBavx!(C, Aₘ, Aₖ′', B)
425+
@test C ≈ C2
426+
fill!(C, 9999.999); rank2AmulB_avx!(C, Aₘ, Aₖ′', B)
427+
@test C ≈ C2
399428
end
400429

401430
end
@@ -456,20 +485,51 @@ using LinearAlgebra
456485
end
457486
s
458487
end
459-
function dot_unroll2(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
488+
function dot_unroll2avx(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
460489
z = zero(T)
461490
@avx unroll=2 for i 1:length(x)
462491
z += x[i]*y[i]
463492
end
464493
return z
465494
end
466-
function dot_unroll3(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
495+
@macroexpand @avx unroll=2 for i ∈ 1:length(x)
496+
z += x[i]*y[i]
497+
end
498+
function dot_unroll3avx(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
467499
z = zero(T)
468500
@avx unroll=3 for i 1:length(x)
469501
z += x[i]*y[i]
470502
end
471503
return z
472504
end
505+
# Dot product of `x` and `y` accumulated in an `@avx` loop with `unroll=2`.
# `inline=false` is passed so the option matches the `_noinline` suffix of the
# function name; the companion `dot_unroll3avx_inline` covers the other setting.
function dot_unroll2avx_noinline(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
    z = zero(T)
    @avx inline=false unroll=2 for i ∈ 1:length(x)
        z += x[i]*y[i]
    end
    return z
end
512+
# Dot product of `x` and `y` accumulated in an `@avx` loop with `unroll=3`.
# `inline=true` is passed so the option matches the `_inline` suffix of the
# function name; the companion `dot_unroll2avx_noinline` covers the other setting.
# The option order (unroll first, inline second) is kept from the original.
function dot_unroll3avx_inline(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
    z = zero(T)
    @avx unroll=3 inline=true for i ∈ 1:length(x)
        z += x[i]*y[i]
    end
    return z
end
519+
# Dot product of `x` and `y`, accumulated inside an `@_avx unroll=2` loop.
function dot_unroll2_avx(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
    acc = zero(T)
    @_avx unroll=2 for i ∈ 1:length(x)
        acc += x[i]*y[i]
    end
    acc
end
526+
# Dot product of `x` and `y`, accumulated inside an `@_avx unroll=3` loop.
function dot_unroll3_avx(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
    acc = zero(T)
    @_avx unroll=3 for i ∈ 1:length(x)
        acc += x[i]*y[i]
    end
    acc
end
473533
function complex_dot_soa(
474534
xre::AbstractVector{T}, xim::AbstractVector{T},
475535
yre::AbstractVector{T}, yim::AbstractVector{T}
@@ -495,8 +555,12 @@ using LinearAlgebra
495555
s = mydot(a,b)
496556
@test mydotavx(a,b) ≈ s
497557
@test mydot_avx(a,b) ≈ s
498-
@test dot_unroll2(a,b) ≈ s
499-
@test dot_unroll3(a,b) ≈ s
558+
@test dot_unroll2avx(a,b) ≈ s
559+
@test dot_unroll3avx(a,b) ≈ s
560+
@test dot_unroll2_avx(a,b) ≈ s
561+
@test dot_unroll3_avx(a,b) ≈ s
562+
@test dot_unroll2avx_noinline(a,b) ≈ s
563+
@test dot_unroll3avx_inline(a,b) ≈ s
500564
s = myselfdot(a)
501565
@test myselfdotavx(a) ≈ s
502566
@test myselfdot_avx(a) ≈ s

0 commit comments

Comments
 (0)