Skip to content

Commit 7b1fc49

Browse files
committed
Updated reduction to track through multiple operations.
1 parent 671bbdd commit 7b1fc49

File tree

7 files changed

+225
-205
lines changed

7 files changed

+225
-205
lines changed

src/add_compute.jl

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,49 +61,66 @@ function add_reduction!(
6161
end
6262
# pushparent!(parents, deps, reduceddeps, parent)
6363
end
64+
function search_tree(opv::Vector{Operation}, var::Symbol) # relies on cycles being forbidden
65+
for opp opv
66+
name(opp) === var && return true
67+
search_tree(parents(opp), var) && return true
68+
end
69+
false
70+
end
71+
function update_reduction_status!(parentvec::Vector{Operation}, deps::Vector{Symbol}, parent::Symbol)
72+
for opp parentvec
73+
if name(opp) === parent
74+
mergesetv!(reducedchildren(opp), deps)
75+
break
76+
elseif search_tree(parents(opp), parent)
77+
mergesetv!(reducedchildren(opp), deps)
78+
update_reduction_status!(parents(opp), deps, parent)
79+
break
80+
end
81+
end
82+
end
6483
function add_reduction_update_parent!(
6584
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet,
66-
var::Symbol, instr::Symbol, elementbytes::Int = 8
85+
var::Symbol, instr::Symbol, directdependency::Bool, elementbytes::Int = 8
6786
)
6887
parent = getop(ls, var, elementbytes)
69-
isloopconstant = parent.instruction === LOOPCONSTANT
88+
isouterreduction = parent.instruction === LOOPCONSTANT
7089
Instr = Instruction(instr)
90+
instrclass = reduction_instruction_class(Instr) # key allows for faster lookups
7191
# if parent is not an outer reduction...
72-
if !isloopconstant
92+
if !isouterreduction
7393
# and parent is not a reduction_zero
74-
reduct_zero = REDUCTION_ZERO[Instr]
75-
reductcombine::Symbol = @static VERSION < v"1.3" ? last(REDUCTION_SCALAR_COMBINE[Instr].args).value : REDUCTION_SCALAR_COMBINE[Instr].name
94+
reduct_zero = reduction_zero(instrclass)
95+
reductcombine = reduction_scalar_combine(instrclass)
7696
reductsym = gensym(:reduction)
7797
reductinit = add_constant!(ls, Expr(:call, reduct_zero, ls.T), loopdependencies(parent), reductsym, reduct_zero, elementbytes)
7898
if isconstant(parent) && reduct_zero === parent.instruction.mod #we can use parent op as initialization.
79-
reductcombine = REDUCTION_COMBINETO[reductcombine]
80-
# else # we cannot use parent op as initialization.
99+
reductcombine = reduction_combine_to(instrclass)
81100
end
82101
else
83102
reductinit = parent
84103
reductsym = var
85104
reductcombine = Symbol("")
86105
end
87-
# mergesetv!(reduceddeps, deps)
88-
# if length(reduceddependencies(reductinit)) == 0
89-
# setdiffv!(reduceddeps, deps, loopdependencies(reductinit))
90-
# else
91106
setdiffv!(reduceddeps, deps, loopdependencies(reductinit))
92-
# end
93-
# mergesetv!(reduceddependencies(reductinit), reduceddeps)
94-
pushparent!(parents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
107+
combineddeps = copy(deps); mergesetv!(combineddeps, reduceddeps)
108+
directdependency && pushparent!(parents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
109+
update_reduction_status!(parents, combineddeps, name(reductinit))
110+
# this is the op added by add_compute
95111
op = Operation(length(operations(ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, parents)
96112
parent.instruction === LOOPCONSTANT && push!(ls.outer_reductions, identifier(op))
97113
opout = pushop!(ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
98-
isloopconstant && return opout
99-
# create child
114+
isouterreduction && return opout
115+
# create child op, which is the reduction combination
100116
childdeps = Symbol[]; childrdeps = Symbol[]; childparents = Operation[]
101117
pushparent!(childparents, childdeps, childrdeps, op) # reduce op
102118
pushparent!(childparents, childdeps, childrdeps, parent) # to
103119
child = Operation(
104120
length(operations(ls)), name(parent), elementbytes, reductcombine, compute, childdeps, childrdeps, childparents
105121
)
106122
pushop!(ls, child, name(parent))
123+
opout
107124
end
108125
function add_compute!(
109126
ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
@@ -136,8 +153,8 @@ function add_compute!(
136153
add_parent!(parents, deps, reduceddeps, ls, arg, elementbytes)
137154
end
138155
end
139-
if reduction # arg[reduction] is the reduction
140-
add_reduction_update_parent!(parents, deps, reduceddeps, ls, var, instr, elementbytes)
156+
if reduction || search_tree(parents, var)
157+
add_reduction_update_parent!(parents, deps, reduceddeps, ls, var, instr, reduction, elementbytes)
141158
else
142159
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
143160
pushop!(ls, op, var)

src/condense_loopset.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct OperationStruct
4444
instruction::Instruction
4545
loopdeps::UInt64
4646
reduceddeps::UInt64
47+
childdeps::UInt64
4748
parents::UInt64
4849
node_type::OperationType
4950
array::UInt8
@@ -60,7 +61,8 @@ filled_4byte_chunks(u::UInt64) = 16 - (leading_zeros(u) >>> 2)
6061
filled_8byte_chunks(u::UInt64) = 8 - (leading_zeros(u) >>> 3)
6162

6263
num_loop_deps(os::OperationStruct) = filled_4byte_chunks(os.loopdeps)
63-
num_reduced_deps(os::OperationStruct) = filled_4byte_chunks(os.reduced_deps)
64+
num_reduced_deps(os::OperationStruct) = filled_4byte_chunks(os.reduceddeps)
65+
num_child_deps(os::OperationStruct) = filled_4byte_chunks(os.childdeps)
6466
num_parents(os::OperationStruct) = filled_4byte_chunks(os.parents)
6567

6668
function shifted_loopset(ls::LoopSet, loopsyms::Vector{Symbol})
@@ -73,6 +75,7 @@ function shifted_loopset(ls::LoopSet, loopsyms::Vector{Symbol})
7375
end
7476
loopdeps_uint(ls::LoopSet, op::Operation) = shifted_loopset(ls, loopdependencies(op))
7577
reduceddeps_uint(ls::LoopSet, op::Operation) = shifted_loopset(ls, reduceddependencies(op))
78+
childdeps_uint(ls::LoopSet, op::Operation) = shifted_loopset(ls, reducedchildren(op))
7679
function parents_uint(ls::LoopSet, op::Operation)
7780
p = zero(UInt64)
7881
for parent parents(op)
@@ -85,10 +88,11 @@ function OperationStruct(ls::LoopSet, op::Operation)
8588
instr = instruction(op)
8689
ld = loopdeps_uint(ls, op)
8790
rd = reduceddeps_uint(ls, op)
91+
cd = childdeps_uint(ls, op)
8892
p = parents_uint(ls, op)
8993
array = accesses_memory(op) ? findmatchingarray(ls, vptr(op.ref)) : 0x00
9094
OperationStruct(
91-
instr, ld, rd, p, op.node_type, array
95+
instr, ld, rd, cd, p, op.node_type, array
9296
)
9397
end
9498
## turn a LoopSet into a type object which can be used to reconstruct the LoopSet.

src/costs.jl

Lines changed: 50 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -156,108 +156,58 @@ const COST = Dict{Instruction,InstructionCost}(
156156
# COST[Symbol("typeof(", lower(k), ")")] = v
157157
# end
158158

159-
const CORRESPONDING_REDUCTION = Dict{Instruction,Instruction}(
160-
Instruction(:+) => Instruction(:vsum),
161-
Instruction(:-) => Instruction(:vsum),
162-
Instruction(:*) => Instruction(:vprod),
163-
Instruction(:vadd) => Instruction(:vsum),
164-
Instruction(:vsub) => Instruction(:vsum),
165-
Instruction(:vmul) => Instruction(:vprod),
166-
Instruction(:evadd) => Instruction(:vsum),
167-
Instruction(:evsub) => Instruction(:vsum),
168-
Instruction(:evmul) => Instruction(:vprod),
169-
Instruction(:&) => Instruction(:vall),
170-
Instruction(:|) => Instruction(:vany),
171-
Instruction(:muladd) => Instruction(:vsum),
172-
Instruction(:fma) => Instruction(:vsum),
173-
Instruction(:vmuladd) => Instruction(:vsum),
174-
Instruction(:vfma) => Instruction(:vsum),
175-
Instruction(:vfmadd) => Instruction(:vsum),
176-
Instruction(:vfmsub) => Instruction(:vsum),
177-
Instruction(:vfnmadd) => Instruction(:vsum),
178-
Instruction(:vfnmsub) => Instruction(:vsum),
179-
Instruction(:vfmadd_fast) => Instruction(:vsum),
180-
Instruction(:vfmsub_fast) => Instruction(:vsum),
181-
Instruction(:vfnmadd_fast) => Instruction(:vsum),
182-
Instruction(:vfnmsub_fast) => Instruction(:vsum)
183-
)
184-
const REDUCTION_TRANSLATION = Dict{Instruction,Instruction}(
185-
Instruction(:+) => Instruction(:evadd),
186-
Instruction(:vadd) => Instruction(:evadd),
187-
Instruction(:*) => Instruction(:evmul),
188-
Instruction(:vmul) => Instruction(:evmul),
189-
Instruction(:-) => Instruction(:evadd),
190-
Instruction(:vsub) => Instruction(:evadd),
191-
Instruction(:/) => Instruction(:evmul),
192-
Instruction(:vfdiv) => Instruction(:evmul),
193-
Instruction(:muladd) => Instruction(:evadd),
194-
Instruction(:fma) => Instruction(:evadd),
195-
Instruction(:vmuladd) => Instruction(:evadd),
196-
Instruction(:vfma) => Instruction(:evadd),
197-
Instruction(:vfmadd) => Instruction(:evadd),
198-
Instruction(:vfmsub) => Instruction(:evadd),
199-
Instruction(:vfnmadd) => Instruction(:evadd),
200-
Instruction(:vfnmsub) => Instruction(:evadd),
201-
Instruction(:vfmadd_fast) => Instruction(:evadd),
202-
Instruction(:vfmsub_fast) => Instruction(:evadd),
203-
Instruction(:vfnmadd_fast) => Instruction(:evadd),
204-
Instruction(:vfnmsub_fast) => Instruction(:evadd)
205-
)
206-
const REDUCTION_ZERO = Dict{Instruction,Symbol}(
207-
Instruction(:+) => :zero,
208-
Instruction(:vadd) => :zero,
209-
Instruction(:evadd) => :zero,
210-
Instruction(:*) => :one,
211-
Instruction(:vmul) => :one,
212-
Instruction(:evmul) => :one,
213-
Instruction(:-) => :zero,
214-
Instruction(:vsub) => :zero,
215-
Instruction(:evsub) => :zero,
216-
Instruction(:/) => :one,
217-
Instruction(:vfdiv) => :one,
218-
Instruction(:evfdiv) => :one,
219-
Instruction(:muladd) => :zero,
220-
Instruction(:fma) => :zero,
221-
Instruction(:vmuladd) => :zero,
222-
Instruction(:vfma) => :zero,
223-
Instruction(:vfmadd) => :zero,
224-
Instruction(:vfmsub) => :zero,
225-
Instruction(:vfnmadd) => :zero,
226-
Instruction(:vfnmsub) => :zero,
227-
Instruction(:vfmadd_fast) => :zero,
228-
Instruction(:vfmsub_fast) => :zero,
229-
Instruction(:vfnmadd_fast) => :zero,
230-
Instruction(:vfnmsub_fast) => :zero
231-
)
159+
const ADDITIVE_IN_REDUCTIONS = 1.0
160+
const MULTIPLICATIVE_IN_REDUCTIONS = 2.0
161+
const ANY = 3.0
162+
const ALL = 4.0
232163

233-
const LVGETPROP = @static VERSION < v"1.3" ? Expr : GlobalRef
234-
# Fast functions, because common pattern is
235-
const REDUCTION_SCALAR_COMBINE = Dict{Instruction,LVGETPROP}(
236-
Instruction(:+) => lv(:reduced_add),
237-
Instruction(:vadd) => lv(:reduced_add),
238-
Instruction(:*) => lv(:reduced_prod),
239-
Instruction(:vmul) => lv(:reduced_prod),
240-
Instruction(:-) => lv(:reduced_add),
241-
Instruction(:vsub) => lv(:reduced_add),
242-
Instruction(:/) => lv(:reduced_prod),
243-
Instruction(:vfdiv) => lv(:reduced_prod),
244-
Instruction(:muladd) => lv(:reduced_add),
245-
Instruction(:fma) => lv(:reduced_add),
246-
Instruction(:vmuladd) => lv(:reduced_add),
247-
Instruction(:vfma) => lv(:reduced_add),
248-
Instruction(:vfmadd) => lv(:reduced_add),
249-
Instruction(:vfmsub) => lv(:reduced_add),
250-
Instruction(:vfnmadd) => lv(:reduced_add),
251-
Instruction(:vfnmsub) => lv(:reduced_add),
252-
Instruction(:vfmadd_fast) => lv(:reduced_add),
253-
Instruction(:vfmsub_fast) => lv(:reduced_add),
254-
Instruction(:vfnmadd_fast) => lv(:reduced_add),
255-
Instruction(:vfnmsub_fast) => lv(:reduced_add)
256-
)
257-
const REDUCTION_COMBINETO = Dict{Symbol,Symbol}(
258-
:reduced_add => :reduce_to_add,
259-
:reduced_prod => :reduce_to_prod
164+
const REDUCTION_CLASS = Dict{Instruction,Float64}(
165+
Instruction(:+) => ADDITIVE_IN_REDUCTIONS,
166+
Instruction(:-) => ADDITIVE_IN_REDUCTIONS,
167+
Instruction(:*) => MULTIPLICATIVE_IN_REDUCTIONS,
168+
Instruction(:vadd) => ADDITIVE_IN_REDUCTIONS,
169+
Instruction(:vsub) => ADDITIVE_IN_REDUCTIONS,
170+
Instruction(:vmul) => MULTIPLICATIVE_IN_REDUCTIONS,
171+
Instruction(:evadd) => ADDITIVE_IN_REDUCTIONS,
172+
Instruction(:evsub) => ADDITIVE_IN_REDUCTIONS,
173+
Instruction(:evmul) => MULTIPLICATIVE_IN_REDUCTIONS,
174+
Instruction(:&) => ALL,
175+
Instruction(:|) => ANY,
176+
Instruction(:muladd) => ADDITIVE_IN_REDUCTIONS,
177+
Instruction(:fma) => ADDITIVE_IN_REDUCTIONS,
178+
Instruction(:vmuladd) => ADDITIVE_IN_REDUCTIONS,
179+
Instruction(:vfma) => ADDITIVE_IN_REDUCTIONS,
180+
Instruction(:vfmadd) => ADDITIVE_IN_REDUCTIONS,
181+
Instruction(:vfmsub) => ADDITIVE_IN_REDUCTIONS,
182+
Instruction(:vfnmadd) => ADDITIVE_IN_REDUCTIONS,
183+
Instruction(:vfnmsub) => ADDITIVE_IN_REDUCTIONS,
184+
Instruction(:vfmadd_fast) => ADDITIVE_IN_REDUCTIONS,
185+
Instruction(:vfmsub_fast) => ADDITIVE_IN_REDUCTIONS,
186+
Instruction(:vfnmadd_fast) => ADDITIVE_IN_REDUCTIONS,
187+
Instruction(:vfnmsub_fast) => ADDITIVE_IN_REDUCTIONS
260188
)
189+
reduction_instruction_class(instr::Symbol) = get(REDUCTION_CLASS, Instruction(instr), NaN)
190+
reduction_instruction_class(instr::Instruction) = get(REDUCTION_CLASS, instr, NaN)
191+
function reduction_to_single_vector(x::Float64)
192+
x == 1.0 ? :evadd : x == 2.0 ? :evmul : x == 3.0 ? :vand : x == 4.0 ? :vor : throw("Reduction not found.")
193+
end
194+
reduction_to_single_vector(x) = reduction_to_single_vector(reduction_instruction_class(x))
195+
function reduction_to_scalar(x::Float64)
196+
x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 3.0 ? :vany : x == 4.0 ? :vall : throw("Reduction not found.")
197+
end
198+
reduction_to_scalar(x) = reduction_to_scalar(reduction_instruction_class(x))
199+
function reduction_scalar_combine(x::Float64)
200+
x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 3.0 ? :reduced_any : x == 4.0 ? :reduced_all : throw("Reduction not found.")
201+
end
202+
reduction_scalar_combine(x) = reduction_scalar_combine(reduction_instruction_class(x))
203+
function reduction_combine_to(x::Float64)
204+
x == 1.0 ? :reduce_to_add : x == 2.0 ? :reduce_to_prod : x == 3.0 ? :reduce_to_any : x == 4.0 ? :reduce_to_all : throw("Reduction not found.")
205+
end
206+
reduction_combine_to(x) = reduction_combine_to(reduction_instruction_class(x))
207+
function reduction_zero(x::Float64)
208+
x == 1.0 ? :zero : x == 2.0 ? :one : x == 3.0 ? :false : x == 4.0 ? :true : throw("Reduction not found.")
209+
end
210+
reduction_zero(x) = reduction_zero(reduction_instruction_class(x))
261211

262212
const FUNCTIONSYMBOLS = Dict{Type{<:Function},Instruction}(
263213
typeof(+) => :(+),

src/lowering.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,15 +208,14 @@ function reduce_range!(q::Expr, ls::LoopSet, Ulow::Int, Uhigh::Int)
208208
op = ls.operations[or]
209209
var = mangledvar(op)
210210
temp = gensym(var)
211-
instr = op.instruction
212-
instr = get(REDUCTION_TRANSLATION, instr, instr)
211+
instr = Instruction(reduction_to_single_vector(op.instruction))
213212
reduce_range!(q, var, instr, Ulow, Uhigh)
214213
end
215214
end
216215

217216
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, U::Int)
218217
U == 1 && return nothing
219-
instr = get(REDUCTION_TRANSLATION, instr, instr)
218+
instr = Instruction(reduction_to_single_vector(instr))
220219
Uh2 = U
221220
iter = 0
222221
while true # combine vectors
@@ -226,8 +225,6 @@ function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, U::Int)
226225
Uh2 = Uh
227226
iter += 1; iter > 4 && throw("Oops! This seems to be excessive unrolling.")
228227
end
229-
# reduce last vector
230-
# push!(q.args, Expr(:(=), assignto, Expr(:call, reductfunc, Symbol(toreduct,:_0))))
231228
nothing
232229
end
233230

@@ -403,7 +400,9 @@ function lower_compute!(
403400
# if op is an inner reduction, one of its parents will be the initialization of op
404401
# They will share the same `variable` field. The initialization may not have
405402
# unrolled in its loop dependencies, but (if opunrolled) op itself is, so we return true
406-
parentsunrolled[p] = var === opp.variable ? true : (unrolled loopdependencies(opp))
403+
# pu = (var === opp.variable || search_tree(parents(opp), var)) ? opunrolled : (unrolled ∈ loopdependencies(opp))
404+
pu = unrolled loopdependencies(opp) || unrolled reducedchildren(opp)
405+
parentsunrolled[p] = pu
407406
end
408407
else # maybe skip allocating this?
409408
parentsunrolled = fill(false, nparents)
@@ -662,8 +661,7 @@ end
662661
function initialize_outer_reductions!(
663662
q::Expr, op::Operation, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, vectorized::Symbol, suffix::Union{Symbol,Nothing} = nothing
664663
)
665-
# T = op.elementbytes == 8 ? :Float64 : :Float32
666-
z = Expr(:call, REDUCTION_ZERO[op.instruction], typeT)
664+
z = Expr(:call, reduction_zero(op.instruction), typeT)
667665
if vectorized reduceddependencies(op)
668666
z = Expr(:call, lv(:vbroadcast), W, z)
669667
end
@@ -705,7 +703,7 @@ function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
705703
mvar = mangledvar(op)
706704
instr = instruction(op)
707705
reduce_expr!(q, mvar, instr, U)
708-
length(ls.opdict) == 0 || push!(q.args, Expr(:(=), var, Expr(:call, REDUCTION_SCALAR_COMBINE[instr], var, Symbol(mvar, 0))))
706+
length(ls.opdict) == 0 || push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), var, Symbol(mvar, 0))))
709707
end
710708
end
711709
function gc_preserve(ls::LoopSet, q::Expr)

src/operations.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ struct Operation
9494
parents::Vector{Operation}
9595
ref::ArrayReferenceMeta
9696
mangledvariable::Symbol
97+
reduced_children::Vector{Symbol}
9798
function Operation(
9899
identifier::Int,
99100
variable,
@@ -103,15 +104,16 @@ struct Operation
103104
dependencies = Symbol[],
104105
reduced_deps = Symbol[],
105106
parents = Operation[],
106-
ref::ArrayReferenceMeta = NOTAREFERENCE
107+
ref::ArrayReferenceMeta = NOTAREFERENCE,
108+
reduced_children = Symbol[]
107109
)
108110
new(
109111
identifier, variable, elementbytes, instruction, node_type,
110112
convert(Vector{Symbol},dependencies),
111113
convert(Vector{Symbol},reduced_deps),
112114
convert(Vector{Operation},parents),
113-
ref,
114-
Symbol("##", variable, :_)
115+
ref, Symbol("##", variable, :_),
116+
reduced_children
115117
)
116118
end
117119
end
@@ -168,6 +170,7 @@ parents(op::Operation) = op.parents
168170
# children(op::Operation) = op.children
169171
loopdependencies(op::Operation) = op.dependencies
170172
reduceddependencies(op::Operation) = op.reduced_deps
173+
reducedchildren(op::Operation) = op.reduced_children
171174
identifier(op::Operation) = op.identifier + 1
172175
vptr(x::Symbol) = Symbol("##vptr##_", x)
173176
vptr(x::ArrayReference) = vptr(x.array)

0 commit comments

Comments
 (0)