Skip to content

Commit ce3147e

Browse files
committed
Reworked handling of loop constants, reducing to two primary strategies (one for loop constants assigned outside of the loops, another for assignments nested within the loops).
1 parent b65eb33 commit ce3147e

12 files changed

+172
-124
lines changed

benchmark/looptests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function jgemm!(C, A, B)
1111
end
1212
end
1313
@inline function gemmavx!(C, A, B)
14-
@avx inline=true for i 1:size(A,1), j 1:size(B,2)
14+
@avx for i 1:size(A,1), j 1:size(B,2)
1515
Cᵢⱼ = zero(eltype(C))
1616
for k 1:size(A,2)
1717
Cᵢⱼ += A[i,k] * B[k,j]

src/LoopVectorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ include("constructors.jl")
3636

3737
export @_avx, _avx, @_avx_, avx_!
3838

39-
include("precompile.jl")
40-
_precompile_()
39+
# include("precompile.jl")
40+
# _precompile_()
4141

4242
end # module

src/add_compute.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,15 @@ function add_reduction_update_parent!(
9494
reduct_zero = reduction_zero(instrclass)
9595
reductcombine = reduction_scalar_combine(instrclass)
9696
reductsym = gensym(:reduction)
97-
reductinit = add_constant!(ls, Expr(:call, reduct_zero, ls.T), loopdependencies(parent), reductsym, reduct_zero, elementbytes)
97+
reductinit = add_constant!(ls, gensym(:reductzero), loopdependencies(parent), reductsym, elementbytes, :numericconstant)
98+
if reduct_zero === :zero
99+
push!(ls.preamble_zeros, identifier(reductinit))
100+
elseif reduct_zero === :one
101+
push!(ls.preamble_ones, identifier(reductinit))
102+
else
103+
pushpreamble!(ls, Expr(:(=), name(reductinit), reductzero))
104+
pushpreamble!(ls, op, name, reductinit)
105+
end
98106
if isconstant(parent) && reduct_zero === parent.instruction.mod #we can use parent op as initialization.
99107
reductcombine = reduction_combine_to(instrclass)
100108
end

src/add_constants.jl

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,45 @@
11
function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
22
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
3-
pushpreamble!(ls, op, mangledvar(op))
3+
pushpreamble!(ls, op, var)
44
pushop!(ls, op, var)
55
end
66
function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
7-
sym = gensym(:temp)
8-
op = Operation(length(operations(ls)), sym, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
9-
temp = gensym(:intermediateconst)
10-
pushpreamble!(ls, Expr(:(=), temp, var))
11-
pushpreamble!(ls, op, temp)
12-
pushop!(ls, op, sym)
7+
sym = gensym(:loopconstant)
8+
pushpreamble!(ls, Expr(:(=), sym, var))
9+
add_constant!(ls, sym, elementbytes)
1310
end
1411
function add_constant!(ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
1512
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
1613
add_vptr!(ls, op)
1714
temp = gensym(:intermediateconstref)
1815
pushpreamble!(ls, Expr(:(=), temp, Expr(:call, lv(:load), mpref.mref.ptr, mem_offset(op, UnrollArgs(zero(Int32), Symbol(""), Symbol(""), nothing)))))
1916
pushpreamble!(ls, op, temp)
20-
pushop!(ls, op, var)
17+
pushop!(ls, op, temp)
2118
end
2219
# This version has loop dependencies. var gets assigned to sym when lowering.
23-
function add_constant!(ls::LoopSet, var::Symbol, deps::Vector{Symbol}, sym::Symbol = gensym(:constant), f::Symbol = Symbol(""), elementbytes::Int = 8)
24-
# length(deps) == 0 && push!(ls.preamble.args, Expr(:(=), sym, var))
25-
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, Instruction(f,var), constant, deps, NODEPENDENCY, NOPARENTS), sym)
20+
function add_constant!(
21+
ls::LoopSet, value::Symbol, deps::Vector{Symbol}, assignedsym::Symbol = gensym(:constant), elementbytes::Int = 8, f::Symbol = Symbol("")
22+
)
23+
op = Operation(length(operations(ls)), assignedsym, elementbytes, Instruction(f, value), constant, deps, NODEPENDENCY, NOPARENTS)
24+
pushop!(ls, op, assignedsym)
25+
end
26+
function add_constant!(
27+
ls::LoopSet, value, deps::Vector{Symbol}, assignedsym::Symbol = gensym(:constant), elementbytes::Int = 8, f::Symbol = Symbol("")
28+
)
29+
intermediary = gensym(:intermediate) # hack, passing meta info here
30+
pushpreamble!(ls, Expr(:(=), intermediary, value))
31+
add_constant!(ls, intermediary, deps, assignedsym, f, elementbytes)
2632
end
27-
2833
function add_constant!(
29-
ls::LoopSet, var, deps::Vector{Symbol}, sym::Symbol = gensym(:constant), f::Symbol = Symbol(""), elementbytes::Int = 8
34+
ls::LoopSet, value::Number, deps::Vector{Symbol}, assignedsym::Symbol = gensym(:constant), elementbytes::Int = 8
3035
)
31-
sym2 = gensym(:temp) # hack, passing meta info here
32-
op = Operation(length(operations(ls)), sym, elementbytes, Instruction(f, sym2), constant, deps, NODEPENDENCY, NOPARENTS)
33-
# @show f, sym, name(op), mangledvar(op)
34-
# temp = gensym(:temp2)
35-
# pushpreamble!(ls, Expr(:(=), temp, var))
36-
pushpreamble!(ls, op, var)#temp)
37-
pushop!(ls, op, sym)
36+
op = add_constant!(ls, gensym(Symbol(value)), deps, assignedsym, elementbytes, :numericconstant)
37+
if iszero(value)
38+
push!(ls.preamble_zeros, identifier(op))
39+
elseif isone(value)
40+
push!(ls.preamble_ones, identifier(op))
41+
else
42+
pushpreamble!(ls, op, value)
43+
end
44+
op
3845
end

src/broadcast.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function add_broadcast!(
7272
loadB = add_broadcast!(ls, gensym(:B), mB, bloopsyms, B, elementbytes)
7373
# set Cₘₙ = 0
7474
# setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
75-
setC = add_constant!(ls, gensym(:zero), cloopsyms, mC, :zero, elementbytes)
75+
setC = add_constant!(ls, gensym(:zero), cloopsyms, mC, elementbytes, :numericconstant)
7676
push!(ls.preamble_zeros, identifier(setC))
7777
# compute Cₘₙ += Aₘₖ * Bₖₙ
7878
reductop = Operation(
@@ -136,11 +136,9 @@ function add_broadcast!(
136136
add_simple_load!(ls, destname, ArrayReference(bcname, @view(loopsyms[1:N])), elementbytes)
137137
end
138138
function add_broadcast!(
139-
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int = 8
140-
) where {T<:Union{Integer,Float32,Float64}}
141-
op = add_constant!(ls, destname, elementbytes) # or replace elementbytes with sizeof(T) ? u
142-
pushpreamble!(ls, Expr(:(=), mangledvar(op), bcname))
143-
op
139+
ls::LoopSet, ::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int = 8
140+
) where {T<:Number}
141+
add_constant!(ls, bcname, elementbytes) # or replace elementbytes with sizeof(T) ? u
144142
end
145143
function add_broadcast!(
146144
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
@@ -168,7 +166,7 @@ function add_broadcast!(
168166
reduceddeps = Symbol[]
169167
for (i,arg) enumerate(args)
170168
argname = gensym(:arg)
171-
pushpreamble!(ls, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,@__FILE__), Expr(:(=), argname, Expr(:ref, bcargs, i))))
169+
pushpreamble!(ls, Expr(:(=), argname, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,@__FILE__), Expr(:ref, bcargs, i))))
172170
# dynamic dispatch
173171
parent = add_broadcast!(ls, gensym(:temp), argname, loopsyms, arg, elementbytes)::Operation
174172
pushparent!(parents, deps, reduceddeps, parent)

src/condense_loopset.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function loop_boundaries(ls::LoopSet)
116116
elseif stopexact
117117
Expr(:call, Expr(:curly, lv(:StaticUpperUnitRange), loop.stophint), loop.startsym)
118118
else
119-
Expr(:call, Expr(:call, :(:), loop.startsym, loop.stopsym))
119+
Expr(:call, :(:), loop.startsym, loop.stopsym)
120120
end
121121
push!(lbd.args, lexpr)
122122
end
@@ -159,6 +159,14 @@ function loopset_return_value(ls::LoopSet, ::Val{extract}) where {extract}
159159
end
160160
end
161161

162+
function add_reassigned_syms!(q::Expr, ls::LoopSet)
163+
for op operations(ls)
164+
if isconstant(op)
165+
instr = instruction(op)
166+
(instr == LOOPCONSTANT || instr.mod === :numericconstant) || push!(q.args, instr.instr)
167+
end
168+
end
169+
end
162170

163171
# Try to condense in type stable manner
164172
function generate_call(ls::LoopSet, IUT)
@@ -177,14 +185,14 @@ function generate_call(ls::LoopSet, IUT)
177185
loop_bounds = loop_boundaries(ls)
178186

179187
q = Expr(:call, lv(:_avx_!), Expr(:call, Expr(:curly, :Val, IUT)), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
180-
181188
foreach(ref -> push!(q.args, vptr(ref)), ls.refs_aliasing_syms)
182189
foreach(is -> push!(q.args, last(is)), ls.preamble_symsym)
183190
append!(q.args, arraysymbolinds)
191+
add_reassigned_syms!(q, ls)
184192
q
185193
end
186194

187-
function setup_call(ls::LoopSet, inline = one(Int8), U = zero(Int8), T = zero(Int8))
195+
function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
188196
call = generate_call(ls, (inline,U,T))
189197
hasouterreductions = length(ls.outer_reductions) > 0
190198
if hasouterreductions

src/costs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ function Base.isless(instr1::Instruction, instr2::Instruction)
2323
isless(instr1.mod, instr2.mod)
2424
end
2525
end
26+
Base.isequal(ins1::Instruction, ins2::Instruction) = (ins1.instr === ins2.instr) && (ins1.mod === ins2.mod)
2627

2728
const LOOPCONSTANT = Instruction(gensym())
2829

src/determinestrategy.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,12 @@ function evaluate_cost_tile(
405405
end
406406
for (id, op) enumerate(ops)
407407
iters[id] == -99.9 && continue
408-
descendentsininnerloop[id] || continue
408+
opisininnerloop = descendentsininnerloop[id]
409409
isunrolled = unrolledtiled[1,id]
410410
istiled = unrolledtiled[2,id]
411411
rt, lat, rp = cost(op, vectorized, Wshift, size_T)
412-
# @show instruction(op), rt, lat, rp, iter
412+
rp = opisininnerloop ? rp : 0 # we only care about register pressure within the inner most loop
413413
rt *= iters[id]
414-
# @show isunrolled, istiled
415414
if isunrolled && istiled # no cost decrease; cost must be repeated
416415
cost_vec[1] += rt
417416
reg_pressure[1] += rp

src/graphs.jl

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -261,18 +261,12 @@ looprangehint(ls::LoopSet, s::Symbol) = length(getloop(ls, s))
261261
looprangesym(ls::LoopSet, s::Symbol) = getloop(ls, s).rangesym
262262
function getop(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
263263
get!(ls.opdict, var) do
264-
# might add constant
265-
op = add_constant!(ls, var, elementbytes)
266-
pushpreamble!(ls, op, var)
267-
op
264+
add_constant!(ls, var, elementbytes)
268265
end
269266
end
270267
function getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int = 8)
271268
get!(ls.opdict, var) do
272-
# might add constant
273-
op = add_constant!(ls, var, deps, gensym(:constant), Symbol(""), elementbytes)
274-
pushpreamble!(ls, op, var)
275-
op
269+
add_constant!(ls, var, deps, gensym(:constant), elementbytes)
276270
end
277271
end
278272
getop(ls::LoopSet, i::Int) = ls.operations[i + 1]
@@ -370,7 +364,7 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
370364
else # neither are integers
371365
L = add_loop_bound!(ls, itersym, lower, false)
372366
U = add_loop_bound!(ls, itersym, upper, true)
373-
Loop(itersym, L, N)
367+
Loop(itersym, L, U)
374368
end
375369
elseif f === :eachindex
376370
N = gensym(Symbol(:loop, itersym))
@@ -434,8 +428,7 @@ function add_operation!(
434428
add_load_getindex!(ls, LHS, RHS, elementbytes)
435429
elseif f === :zero || f === :one
436430
c = gensym(f)
437-
# pushpreamble!(ls, Expr(:(=), c, RHS))
438-
op = add_constant!(ls, c, copy(ls.loopsymbols), LHS, f, elementbytes)
431+
op = add_constant!(ls, c, copy(ls.loopsymbols), LHS, elementbytes, :numericconstant)
439432
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, identifier(op))
440433
op
441434
else
@@ -459,8 +452,7 @@ function add_operation!(
459452
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
460453
elseif f === :zero || f === :one
461454
c = gensym(f)
462-
# pushpreamble!(ls, Expr(:(=), c, RHS))
463-
op = add_constant!(ls, c, copy(ls.loopsymbols), LHS_sym, f, elementbytes)
455+
op = add_constant!(ls, c, copy(ls.loopsymbols), LHS_sym, elementbytes, :numericconstant)
464456
push!(f === :zero ? ls.preamble_zeros : ls.preamble_ones, identifier(op))
465457
op
466458
else
@@ -488,33 +480,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
488480
if RHS isa Expr
489481
add_operation!(ls, LHS, RHS, elementbytes)
490482
else
491-
deps = copy(ls.loopsymbols)
492-
if RHS isa Number
493-
fisone = false
494-
fiszero = false
495-
instr = if iszero(RHS)
496-
fiszero = true
497-
:zero
498-
elseif isone(RHS)
499-
fisone = true
500-
:one
501-
else
502-
:numericalconstant
503-
end
504-
op = add_constant!(ls, gensym(instr), deps, LHS, instr, elementbytes)
505-
if fiszero
506-
push!(ls.preamble_zeros, identifier(op))
507-
elseif fisone
508-
push!(ls.preamble_ones, identifier(op))
509-
else
510-
pushpreamble!(ls, op, RHS)
511-
end
512-
op
513-
elseif RHS isa Symbol
514-
add_constant!(ls, RHS, deps, LHS, :constsym, elementbytes)
515-
else
516-
add_constant!(ls, RHS, deps, LHS, :constmisc, elementbytes)
517-
end
483+
add_constant!(ls, RHS, copy(ls.loopsymbols), LHS, elementbytes)
518484
end
519485
elseif LHS isa Expr
520486
@assert LHS.head === :ref

src/lowering.jl

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -877,36 +877,37 @@ end
877877
@inline sizeequivalentint(::Type{Float16}, x::Int64) = Int16(x)
878878
@inline sizeequivalentint(::Type{Float16}, x::Int32) = Int16(x)
879879

880+
function setop!(ls, op, val)
881+
if instruction(op) === LOOPCONSTANT# && mangledvar(op) !== val
882+
pushpreamble!(ls, Expr(:(=), mangledvar(op), val))
883+
else
884+
pushpreamble!(ls, Expr(:(=), instruction(op).instr, val))
885+
end
886+
nothing
887+
end
888+
function setconstantop!(ls, op, val)
889+
if instruction(op) === LOOPCONSTANT# && mangledvar(op) !== val
890+
pushpreamble!(ls, Expr(:(=), mangledvar(op), val))
891+
end
892+
nothing
893+
end
894+
880895
function lower_licm_constants!(ls::LoopSet)
881896
ops = operations(ls)
882-
for (id,sym) ls.preamble_symsym
883-
op = ops[id]
884-
mv = mangledvar(op)
885-
# mv === sym || pushpreamble!(ls, Expr(:(=), instruction(op).instr, sym))
886-
if mv !== sym
887-
if length(ls.includedarrays) == 0 && instruction(op).mod === Symbol("")
888-
pushpreamble!(ls, Expr(:(=), instruction(op).instr, sym))
889-
else
890-
pushpreamble!(ls, Expr(:(=), mv, sym))
891-
end
892-
# pushpreamble!(ls, instruction(op) === LOOPCONSTANT ? Expr(:(=), mv, sym) : Expr(:(=), instruction(op).instr, sym))
893-
end
897+
for (id, sym) ls.preamble_symsym
898+
setconstantop!(ls, ops[id], sym)
894899
end
895900
for (id,intval) ls.preamble_symint
896-
op = ops[id]
897-
pushpreamble!(ls, Expr(:(=), mangledvar(op), Expr(:call, lv(:sizeequivalentint), ls.T, intval)))
901+
setop!(ls, ops[id], Expr(:call, lv(:sizeequivalentint), ls.T, intval))
898902
end
899903
for (id,floatval) ls.preamble_symfloat
900-
op = ops[id]
901-
pushpreamble!(ls, Expr(:(=), mangledvar(op), Expr(:call, lv(:sizeequivalentfloat), ls.T, floatval)))
904+
setop!(ls, ops[id], Expr(:call, lv(:sizeequivalentfloat), ls.T, intval))
902905
end
903906
for id ls.preamble_zeros
904-
op = ops[id]
905-
pushpreamble!(ls, Expr(:(=), instruction(op).instr, Expr(:call, :zero, ls.T)))
907+
setop!(ls, ops[id], Expr(:call, :zero, ls.T))
906908
end
907909
for id ls.preamble_ones
908-
op = ops[id]
909-
pushpreamble!(ls, Expr(:(=), instruction(op).instr, Expr(:call, :one, ls.T)))
910+
setop!(ls, ops[id], Expr(:call, :one, ls.T))
910911
end
911912
end
912913
function setup_preamble!(ls::LoopSet, W::Symbol, typeT::Symbol, vectorized::Symbol, unrolled::Symbol, tiled::Symbol, U::Int)

0 commit comments

Comments
 (0)