Skip to content

Commit 681f828

Browse files
authored
Merge pull request #75 from timholy/teh/names
Preserve names of arrays and iteration variables across generated boundary
2 parents 515169c + 7b43d46 commit 681f828

File tree

2 files changed

+45
-41
lines changed

2 files changed

+45
-41
lines changed

src/condense_loopset.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
Base.:|(u::Unsigned, it::IndexType) = u | UInt8(it)
55
Base.:(==)(u::Unsigned, it::IndexType) = (u % UInt8) == UInt8(it)
66

7-
struct ArrayRefStruct
7+
struct ArrayRefStruct{array,ptr}
88
index_types::UInt64
99
indices::UInt64
1010
offsets::UInt64
1111
end
12+
array(ar::ArrayRefStruct{a,p}) where {a,p} = a
13+
ptr(ar::ArrayRefStruct{a,p}) where {a,p} = p
1214

1315
function findindoradd!(v::Vector{T}, s::T) where {T}
1416
ind = findfirst(sᵢ -> sᵢ == s, v)
@@ -43,7 +45,7 @@ function ArrayRefStruct(ls::LoopSet, mref::ArrayReferenceMeta, arraysymbolinds::
4345
end
4446
end
4547
end
46-
ArrayRefStruct( index_types, indices, offsets )
48+
ArrayRefStruct{mref.ref.array,mref.ptr}( index_types, indices, offsets )
4749
end
4850

4951
struct OperationStruct <: AbstractLoopOperation
@@ -189,14 +191,15 @@ end
189191
@inline array_wrapper(A::SubArray) = A.indices
190192

191193

192-
194+
# If you change the number of arguments here, make commensurate changes
195+
# to the `insert!` locations in `setup_call_noinline`.
193196
@generated function __avx__!(
194-
::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, lb::LB,
197+
::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB,
195198
::Val{AR}, ::Val{D}, ::Val{IND}, subsetvals, arraydescript, vargs::Vararg{<:Any,N}
196-
) where {UT, OPS, ARF, AM, LB, N, AR, D, IND}
199+
) where {UT, OPS, ARF, AM, LPSYM, LB, N, AR, D, IND}
197200
num_vptrs = length(ARF.parameters)::Int
198201
vptrs = [gensym(:vptr) for _ 1:num_vptrs]
199-
call = Expr(:call, lv(:_avx_!), Val{UT}(), OPS, ARF, AM, :lb)
202+
call = Expr(:call, lv(:_avx_!), Val{UT}(), OPS, ARF, AM, LPSYM, :lb)
200203
for n 1:num_vptrs
201204
push!(call.args, vptrs[n])
202205
end
@@ -241,21 +244,22 @@ function generate_call(ls::LoopSet, IUT, debug::Bool = false)
241244
foreach(ref -> push!(arrayref_descriptions.args, ArrayRefStruct(ls, ref, arraysymbolinds)), ls.refs_aliasing_syms)
242245
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
243246
loop_bounds = loop_boundaries(ls)
247+
loop_syms = Expr(:curly, :Tuple, map(QuoteNode, ls.loopsymbols)...)
244248
inline, U, T = IUT
245249
if inline | debug
246250
func = debug ? lv(:_avx_loopset_debug) : lv(:_avx_!)
247251
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
248252
q = Expr(
249253
:call, func, Expr(:call, Expr(:curly, :Val, (U,T))),
250-
operation_descriptions, arrayref_descriptions, argmeta, lbarg
254+
operation_descriptions, arrayref_descriptions, argmeta, loop_syms, lbarg
251255
)
252256
debug && deleteat!(q.args, 2)
253257
foreach(ref -> push!(q.args, vptr(ref)), ls.refs_aliasing_syms)
254258
else
255259
arraydescript = Expr(:tuple)
256260
q = Expr(
257261
:call, lv(:__avx__!), Expr(:call, Expr(:curly, :Val, (U,T))),
258-
operation_descriptions, arrayref_descriptions, argmeta, loop_bounds, arraydescript
262+
operation_descriptions, arrayref_descriptions, argmeta, loop_syms, loop_bounds, arraydescript
259263
)
260264
for array ls.includedactualarrays
261265
push!(q.args, Expr(:call, lv(:unwrap_array), array))
@@ -315,10 +319,10 @@ function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
315319
end
316320
push!(q.args, ex)
317321
end
318-
insert!(call.args, 7, Expr(:call, Expr(:curly, :Val, vptrarrays)))
319-
insert!(call.args, 8, Expr(:call, Expr(:curly, :Val, vptrsubsetdims)))
320-
insert!(call.args, 9, Expr(:call, Expr(:curly, :Val, vptrindices)))
321-
insert!(call.args, 10, vptrsubsetvals)
322+
insert!(call.args, 8, Expr(:call, Expr(:curly, :Val, vptrarrays)))
323+
insert!(call.args, 9, Expr(:call, Expr(:curly, :Val, vptrsubsetdims)))
324+
insert!(call.args, 10, Expr(:call, Expr(:curly, :Val, vptrindices)))
325+
insert!(call.args, 11, vptrsubsetvals)
322326
if hasouterreductions
323327
outer_reducts = Expr(:local)
324328
for or ls.outer_reductions

src/reconstruct_loopset.jl

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
function Loop(ls::LoopSet, l::Int, ::Type{UnitRange{Int}})
2-
start = gensym(:loopstart); stop = gensym(:loopstop)
1+
function Loop(ls::LoopSet, l::Int, sym::Symbol, ::Type{UnitRange{Int}})
2+
start = gensym(String(sym)*"_loopstart"); stop = gensym(String(sym)*"_loopstop")
33
pushpreamble!(ls, Expr(:(=), start, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:start)))))
44
pushpreamble!(ls, Expr(:(=), stop, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:stop)))))
5-
Loop(gensym(:n), 0, 1024, start, stop, false, false)::Loop
5+
Loop(sym, 0, 1024, start, stop, false, false)::Loop
66
end
7-
function Loop(ls::LoopSet, l::Int, ::Type{StaticUpperUnitRange{U}}) where {U}
8-
start = gensym(:loopstart)
7+
function Loop(ls::LoopSet, l::Int, sym::Symbol, ::Type{StaticUpperUnitRange{U}}) where {U}
8+
start = gensym(String(sym)*"_loopstart")
99
pushpreamble!(ls, Expr(:(=), start, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:L)))))
10-
Loop(gensym(:n), U - 1024, U, start, Symbol(""), false, true)::Loop
10+
Loop(sym, U - 1024, U, start, Symbol(""), false, true)::Loop
1111
end
12-
function Loop(ls::LoopSet, l::Int, ::Type{StaticLowerUnitRange{L}}) where {L}
13-
stop = gensym(:loopstop)
12+
function Loop(ls::LoopSet, l::Int, sym::Symbol, ::Type{StaticLowerUnitRange{L}}) where {L}
13+
stop = gensym(String(sym)*"_loopstop")
1414
pushpreamble!(ls, Expr(:(=), stop, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:U)))))
15-
Loop(gensym(:n), L, L + 1024, Symbol(""), stop, true, false)::Loop
15+
Loop(sym, L, L + 1024, Symbol(""), stop, true, false)::Loop
1616
end
1717
# Is there any likely way to generate such a range?
1818
# function Loop(ls::LoopSet, l::Int, ::Type{StaticLengthUnitRange{N}}) where {N}
@@ -21,19 +21,19 @@ end
2121
# pushpreamble!(ls, Expr(:(=), stop, Expr(:call, :(+), start, N - 1)))
2222
# Loop(gensym(:n), 0, N, start, stop, false, false)::Loop
2323
# end
24-
function Loop(ls, l, ::Type{StaticUnitRange{L,U}}) where {L,U}
25-
Loop(gensym(:n), L, U, Symbol(""), Symbol(""), true, true)::Loop
24+
function Loop(ls, l, sym::Symbol, ::Type{StaticUnitRange{L,U}}) where {L,U}
25+
Loop(sym, L, U, Symbol(""), Symbol(""), true, true)::Loop
2626
end
2727

28-
function add_loops!(ls::LoopSet, LB)
29-
loopsyms = [gensym(:n) for _ eachindex(LB)]
30-
for (i,l) enumerate(LB)
31-
add_loop!(ls, Loop(ls, i, l)::Loop)
28+
function add_loops!(ls::LoopSet, LPSYM, LB)
29+
n = max(length(LPSYM), length(LB))
30+
for i = 1:n
31+
add_loop!(ls, Loop(ls, i, LPSYM[i], LB[i])::Loop)
3232
end
3333
end
34+
3435
function ArrayReferenceMeta(
35-
ls::LoopSet, ar::ArrayRefStruct, arraysymbolinds::Vector{Symbol}, opsymbols::Vector{Symbol},
36-
array::Symbol, vp::Symbol
36+
ls::LoopSet, @nospecialize(ar::ArrayRefStruct), arraysymbolinds::Vector{Symbol}, opsymbols::Vector{Symbol}
3737
)
3838
index_types = ar.index_types
3939
indices = ar.indices
@@ -61,8 +61,8 @@ function ArrayReferenceMeta(
6161
ni -= 1
6262
end
6363
ArrayReferenceMeta(
64-
ArrayReference(array, index_vec, offset_vec),
65-
loopedindex, vp
64+
ArrayReference(array(ar), index_vec, offset_vec),
65+
loopedindex, ptr(ar)
6666
)
6767
end
6868

@@ -105,7 +105,7 @@ end
105105
function create_mrefs!(ls::LoopSet, arf::Vector{ArrayRefStruct}, as::Vector{Symbol}, os::Vector{Symbol}, vargs)
106106
mrefs = Vector{ArrayReferenceMeta}(undef, length(arf))
107107
for i eachindex(arf)
108-
ar = ArrayReferenceMeta(ls, arf[i], as, os, Symbol(""), gensym())
108+
ar = ArrayReferenceMeta(ls, arf[i], as, os)
109109
add_mref!(ls, ar, i, vargs[i])
110110
mrefs[i] = ar
111111
end
@@ -222,11 +222,11 @@ function sizeofeltypes(v, num_arrays)::Int
222222
sizeof(T)
223223
end
224224

225-
function avx_loopset(instr, ops, arf, AM, LB, vargs)
225+
function avx_loopset(instr, ops, arf, AM, LPSYM, LB, vargs)
226226
ls = LoopSet(:LoopVectorization)
227227
num_arrays = length(arf)
228228
elementbytes = sizeofeltypes(vargs, num_arrays)
229-
add_loops!(ls, LB)
229+
add_loops!(ls, LPSYM, LB)
230230
resize!(ls.loop_order, length(LB))
231231
arraysymbolinds = process_metadata!(ls, AM, length(arf))
232232
opsymbols = [gensym(:op) for _ eachindex(ops)]
@@ -245,22 +245,22 @@ function avx_body(ls, UT)
245245
q
246246
end
247247

248-
function _avx_loopset_debug(::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LB}, vargs...) where {UT, OPS, ARF, AM, LB}
249-
@show OPS ARF AM LB vargs
250-
_avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LB.parameters, typeof.(vargs))
248+
function _avx_loopset_debug(::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, ::Type{LB}, vargs...) where {UT, OPS, ARF, AM, LPSYM, LB}
249+
@show OPS ARF AM LPSYM LB vargs
250+
_avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, typeof.(vargs))
251251
end
252-
function _avx_loopset(OPSsv, ARFsv, AMsv, LBsv, vargs) where {UT, OPS, ARF, AM, LB}
252+
function _avx_loopset(OPSsv, ARFsv, AMsv, LPSYMsv, LBsv, vargs)
253253
nops = length(OPSsv) ÷ 3
254254
instr = Instruction[Instruction(OPSsv[3i+1], OPSsv[3i+2]) for i 0:nops-1]
255255
ops = OperationStruct[ OPSsv[3i] for i 1:nops ]
256256
avx_loopset(
257257
instr, ops,
258258
ArrayRefStruct[ARFsv...],
259-
AMsv, LBsv, vargs
259+
AMsv, LPSYMsv, LBsv, vargs
260260
)
261261
end
262-
@generated function _avx_!(::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, lb::LB, vargs...) where {UT, OPS, ARF, AM, LB}
263-
ls = _avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LB.parameters, vargs)
262+
@generated function _avx_!(::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB, vargs...) where {UT, OPS, ARF, AM, LPSYM, LB}
263+
ls = _avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, vargs)
264264
avx_body(ls, UT)
265265
end
266266

0 commit comments

Comments
 (0)