Skip to content

Commit 93afab1

Browse files
committed
Tests on lines 775 and 900 are currently broken.
1 parent 7b1fc49 commit 93afab1

File tree

7 files changed

+178
-78
lines changed

7 files changed

+178
-78
lines changed

src/LoopVectorization.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ module LoopVectorization
33
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
55
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub, _MM,
6-
maybestaticlength, maybestaticsize, Static, staticm1, subsetview
6+
maybestaticlength, maybestaticsize, staticm1, subsetview,
7+
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange,
8+
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct
79
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod
810
using Base.Broadcast: Broadcasted, DefaultArrayStyle
911
using LinearAlgebra: Adjoint, Transpose
@@ -29,9 +31,11 @@ include("determinestrategy.jl")
2931
include("lowering.jl")
3032
include("constructors.jl")
3133
include("map.jl")
32-
include("_avx.jl")
34+
# include("_avx.jl")
35+
include("condense_loopset.jl")
36+
include("reconstruct_loopset.jl")
3337

34-
export @_avx, _avx
38+
export @_avx, _avx, @_avx_, avx_!
3539

3640
# include("precompile.jl")
3741
# _precompile_()

src/condense_loopset.jl

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ function ArrayRefStruct(ls::LoopSet, mref::ArrayReferenceMeta, arraysymbolinds::
2424
for (n,ind) ∈ enumerate(@view(indv[start:end]))
2525
index_types <<= 8
2626
indices <<= 8
27-
if mref.loopindex[n]
27+
if mref.loopedindex[n]
2828
index_types |= LoopIndex
29+
indices |= getloopid(ls, ind)
2930
else
30-
parent = getop(opdict, ind, nothing)
31+
parent = get(ls.opdict, ind, nothing)
3132
if parent === nothing
3233
index_types |= SymbolicIndex
3334
indices |= findindoradd!(arraysymbolinds, ind)
@@ -41,14 +42,19 @@ function ArrayRefStruct(ls::LoopSet, mref::ArrayReferenceMeta, arraysymbolinds::
4142
end
4243

4344
struct OperationStruct
44-
instruction::Instruction
45+
# instruction::Instruction
4546
loopdeps::UInt64
4647
reduceddeps::UInt64
4748
childdeps::UInt64
4849
parents::UInt64
4950
node_type::OperationType
5051
array::UInt8
52+
symid::UInt8
5153
end
54+
isload(os::OperationStruct) = os.node_type == memload
55+
isstore(os::OperationStruct) = os.node_type == memstore
56+
iscompute(os::OperationStruct) = os.node_type == compute
57+
isconstant(os::OperationStruct) = os.node_type == constant
5258
function findmatchingarray(ls::LoopSet, array::Symbol)
5359
id = 0x01
5460
for as ∈ ls.refs_aliasing_syms
@@ -80,19 +86,19 @@ function parents_uint(ls::LoopSet, op::Operation)
8086
p = zero(UInt64)
8187
for parent ∈ parents(op)
8288
p <<= 8
83-
p |= identifier(op)
89+
p |= identifier(parent)
8490
end
8591
p
8692
end
87-
function OperationStruct(ls::LoopSet, op::Operation)
93+
function OperationStruct!(varnames::Vector{Symbol}, ls::LoopSet, op::Operation)
8894
instr = instruction(op)
8995
ld = loopdeps_uint(ls, op)
9096
rd = reduceddeps_uint(ls, op)
9197
cd = childdeps_uint(ls, op)
9298
p = parents_uint(ls, op)
9399
array = accesses_memory(op) ? findmatchingarray(ls, vptr(op.ref)) : 0x00
94100
OperationStruct(
95-
instr, ld, rd, cd, p, op.node_type, array
101+
ld, rd, cd, p, op.node_type, array, findindoradd!(varnames, name(op))
96102
)
97103
end
98104
## turn a LoopSet into a type object which can be used to reconstruct the LoopSet.
@@ -112,12 +118,12 @@ function loop_boundaries(ls::LoopSet)
112118
else
113119
Expr(:call, Expr(:call, :(:), loop.startsym, loop.stopsym))
114120
end
115-
push!(lbd, lexpr)
121+
push!(lbd.args, lexpr)
116122
end
117123
lbd
118124
end
119125

120-
function argmeta_and_costs_description(ls::LoopSet, arraysymbolinds)
126+
function argmeta_and_consts_description(ls::LoopSet, arraysymbolinds)
121127
Expr(
122128
:curly, :Tuple,
123129
length(arraysymbolinds),
@@ -130,14 +136,22 @@ function argmeta_and_costs_description(ls::LoopSet, arraysymbolinds)
130136
)
131137
end
132138

133-
function loopset_return_value(ls::LoopSet)
139+
function loopset_return_value(ls::LoopSet, ::Val{extract}) where {extract}
134140
if length(ls.outer_reductions) == 1
135-
Expr(:call, :extract_data, Symbol(mangledvar(operations(ls)[ls.outer_reductions[1]]), 0))
141+
if extract
142+
Expr(:call, :extract_data, Symbol(mangledvar(operations(ls)[ls.outer_reductions[1]]), 0))
143+
else
144+
Symbol(mangledvar(operations(ls)[ls.outer_reductions[1]]), 0)
145+
end
136146
elseif length(ls.outer_reductions) > 1
137147
ret = Expr(:tuple)
138148
ops = operations(ls)
139149
for or ∈ ls.outer_reductions
140-
push!(ret.args, Expr(:call, :extract_data, Symbol(mangledvar(ops[or]), 0)))
150+
if extract
151+
push!(ret.args, Expr(:call, :extract_data, Symbol(mangledvar(ops[or]), 0)))
152+
else
153+
push!(ret.args, Symbol(mangledvar(ops[or]), 0))
154+
end
141155
end
142156
ret
143157
else
@@ -149,14 +163,20 @@ end
149163
# Try to condense in type stable manner
150164
function generate_call(ls::LoopSet)
151165
operation_descriptions = Expr(:curly, :Tuple)
152-
foreach(op -> push!(operation_descriptions.args, OperationStruct(ls, op)), operations(ls))
166+
varnames = Symbol[]
167+
for op ∈ operations(ls)
168+
instr = instruction(op)
169+
push!(operation_descriptions.args, QuoteNode(instr.mod))
170+
push!(operation_descriptions.args, QuoteNode(instr.instr))
171+
push!(operation_descriptions.args, OperationStruct!(varnames, ls, op))
172+
end
153173
arraysymbolinds = Symbol[]
154174
arrayref_descriptions = Expr(:curly, :Tuple)
155175
foreach(ref -> push!(arrayref_descriptions.args, ArrayRefStruct(ls, ref, arraysymbolinds)), ls.refs_aliasing_syms)
156176
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
157177
loop_bounds = loop_boundaries(ls)
158178

159-
q = Expr(:call, :_avx!, operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
179+
q = Expr(:call, lv(:_avx_!), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
160180

161181
foreach(ref -> push!(q.args, vptr(ref)), ls.refs_aliasing_syms)
162182
foreach(is -> push!(q.args, last(is)), ls.preamble_symsym)
@@ -166,16 +186,29 @@ end
166186

167187
function setup_call(ls::LoopSet)
168188
call = generate_call(ls)
169-
retv = loopset_return_value(ls)
170-
q = Expr(:block,gc_preserve(ls, Expr(:(=), retv, call)))
189+
hasouterreductions = length(ls.outer_reductions) > 0
190+
if hasouterreductions
191+
retv = loopset_return_value(ls, Val(false))
192+
call = Expr(:(=), retv, call)
193+
end
194+
q = Expr(:block,gc_preserve(ls, call))
195+
outer_reducts = Expr(:local)
171196
for or ∈ ls.outer_reductions
172197
op = ls.operations[or]
173198
var = name(op)
174199
mvar = mangledvar(op)
175200
instr = instruction(op)
176-
push!(q.args, Expr(:(=), var, Expr(:call, REDUCTION_SCALAR_COMBINE[instr], var, Symbol(mvar, 0))))
201+
out = Symbol(mvar, 0)
202+
push!(outer_reducts.args, out)
203+
# push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Expr(:call, lv(:SVec), out), var)))
204+
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), out, var)))
177205
end
178-
206+
hasouterreductions && pushpreamble!(ls, outer_reducts)
207+
append!(ls.preamble.args, q.args)
208+
ls.preamble
179209
end
180210

211+
macro _avx(q)
212+
esc(setup_call(LoopSet(q)))
213+
end
181214

src/graphs.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,12 @@ end
195195
# @assert id !== nothing
196196
# ls.refs_aliasing_syms[id]
197197
# end
198-
pushpreamble!(ls::LoopSet, op::Operation, v::Symbol) = push!(ls.preamble_symsym, (identifier(op),v))
198+
function pushpreamble!(ls::LoopSet, op::Operation, v::Symbol)
199+
if v !== mangledvar(op)
200+
push!(ls.preamble_symsym, (identifier(op),v))
201+
end
202+
nothing
203+
end
199204
pushpreamble!(ls::LoopSet, op::Operation, v::Integer) = push!(ls.preamble_symint, (identifier(op),convert(Int,v)))
200205
pushpreamble!(ls::LoopSet, op::Operation, v::Real) = push!(ls.preamble_symfloat, (identifier(op),convert(Float64,v)))
201206
pushpreamble!(ls::LoopSet, ex::Expr) = push!(ls.preamble.args, ex)
@@ -405,7 +410,7 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
405410
push!(ls, q, elementbytes)
406411
end
407412
end
408-
function add_loop!(ls::LoopSet, loop::Loop, itersym::Symbol = loop.itersym)
413+
function add_loop!(ls::LoopSet, loop::Loop, itersym::Symbol = loop.itersymbol)
409414
push!(ls.loopsymbols, itersym)
410415
push!(ls.loops, loop)
411416
nothing

src/lowering.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ function symbolind(ind::Symbol, op::Operation, td::UnrollArgs)
3232
else
3333
mangledvar(parent)
3434
end
35-
if unrolled ∈ loopdependencies(parent)
36-
pvar = Symbol(pvar, u)
37-
end
35+
pvar = unrolled ∈ loopdependencies(parent) ? Symbol(pvar, u) : pvar
3836
Expr(:call, :-, pvar, one(Int32))
3937
end
4038
function mem_offset(op::Operation, td::UnrollArgs)
@@ -473,7 +471,9 @@ function lower_constant!(
473471
instruction = op.instruction
474472
mvar = variable_name(op, suffix)
475473
constsym = instruction.instr
474+
# constsym = mangledvar(op)
476475
if vectorized ∈ loopdependencies(op) || vectorized ∈ reduceddependencies(op)
476+
# call = Expr(:call, lv(:vbroadcast), W, mangledvar(op))
477477
call = Expr(:call, lv(:vbroadcast), W, constsym)
478478
for u ∈ 0:U-1
479479
push!(q.args, Expr(:(=), Symbol(mvar, u), call))
@@ -882,7 +882,15 @@ function lower_licm_constants!(ls::LoopSet)
882882
for (id,sym) ∈ ls.preamble_symsym
883883
op = ops[id]
884884
mv = mangledvar(op)
885-
mv === sym || pushpreamble!(ls, Expr(:(=), mv, sym))
885+
# mv === sym || pushpreamble!(ls, Expr(:(=), instruction(op).instr, sym))
886+
if mv !== sym
887+
if length(ls.includedarrays) == 0 && instruction(op).mod === Symbol("")
888+
pushpreamble!(ls, Expr(:(=), instruction(op).instr, sym))
889+
else
890+
pushpreamble!(ls, Expr(:(=), mv, sym))
891+
end
892+
# pushpreamble!(ls, instruction(op) === LOOPCONSTANT ? Expr(:(=), mv, sym) : Expr(:(=), instruction(op).instr, sym))
893+
end
886894
end
887895
for (id,intval) ∈ ls.preamble_symint
888896
op = ops[id]
@@ -903,7 +911,7 @@ function lower_licm_constants!(ls::LoopSet)
903911
end
904912
function setup_preamble!(ls::LoopSet, W::Symbol, typeT::Symbol, vectorized::Symbol, unrolled::Symbol, tiled::Symbol, U::Int)
905913
# println("Setup preamble")
906-
push!(ls.preamble.args, Expr(:(=), typeT, determine_eltype(ls)))
914+
length(ls.includedarrays) == 0 || push!(ls.preamble.args, Expr(:(=), typeT, determine_eltype(ls)))
907915
push!(ls.preamble.args, Expr(:(=), W, determine_width(ls, typeT, unrolled)))
908916
lower_licm_constants!(ls)
909917
pushpreamble!(ls, definemask(getloop(ls, vectorized), W, U > 1 && unrolled === vectorized))

src/memory_ops_common.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ function add_vptr!(ls::LoopSet, array::Symbol, vptrarray::Symbol = vptr(array))
77
end
88
nothing
99
end
10-
function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind::Integer)
10+
function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind::Union{Symbol,Int})
1111
subsetvptr = Symbol(vptr, "_subset_$(indnum)_with_$(ind)##")
12-
pushpreamble!(ls, Expr(:(=), subsetvptr, Expr(:call, lv(:subsetview), vptr, Expr(:call, Expr(:curly, :Val, indnum)), ind)))
12+
inde = ind isa Symbol ? Expr(:call, :-, ind, 1) : ind - 1
13+
pushpreamble!(ls, Expr(:(=), subsetvptr, Expr(:call, lv(:subsetview), vptr, Expr(:call, Expr(:curly, :Val, indnum)), inde)))
1314
subsetvptr
1415
end
15-
16+
const DISCONTIGUOUS = Symbol("##DISCONTIGUOUSSUBARRAY##")
1617
function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementbytes::Int = 8)
1718
vptrarray = vptr(array)
1819
add_vptr!(ls, array, vptrarray) # now, subset
@@ -23,24 +24,26 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
2324
loopdependencies = Symbol[]
2425
reduceddeps = Symbol[]
2526
loopset = ls.loopsymbols
27+
ninds = 1
2628
for ind ∈ rawindices
2729
if ind isa Integer # subset
28-
vptrarray = subset_vptr!(ls, vptrarray, length(indices) + 1, ind - 1)
29-
length(indices) == 0 && push!(indices, Symbol("##DISCONTIGUOUSSUBARRAY##"))
30+
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind)
31+
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
32+
elseif ind isa Expr || (ind isa Symbol && ind ∈ keys(ls.opdict))
33+
parent = add_operation!(ls, gensym(:indexpr), ind, elementbytes)
34+
pushparent!(parents, loopdependencies, reduceddeps, parent)
35+
# var = get(ls.opdict, ind, nothing)
36+
push!(indices, name(parent)); ninds += 1
37+
push!(loopedindex, false)
3038
elseif ind isa Symbol
31-
push!(indices, ind)
3239
if ind ∈ loopset
40+
push!(indices, ind); ninds += 1
3341
push!(loopedindex, true)
3442
push!(loopdependencies, ind)
3543
else
36-
push!(loopedindex, false)
44+
vptrarray = subset_vptr!(ls, vptrarray, ninds, ind)
45+
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
3746
end
38-
elseif ind isa Expr
39-
parent = add_operation!(ls, gensym(:indexpr), ind, elementbytes)
40-
pushparent!(parents, loopdependencies, reduceddeps, parent)
41-
# var = get(ls.opdict, ind, nothing)
42-
push!(indices, name(parent))#mangledvar(parent)
43-
push!(loopedindex, false)
4447
else
4548
throw("Unrecognized loop index: $ind.")
4649
end

0 commit comments

Comments
 (0)