1
1
2
- # TODO : FIXME for general case
3
- # wrong for transposed matrices, and certain views/SubArrays.
4
- unitstride (op:: Operation , s) = first (getindices (op)) === s
2
+ function indexappearences (op:: Operation , s:: Symbol )
3
+ s ∉ loopdependencies (op) && return 0
4
+ appearences = 0
5
+ if isloopvalue (op)
6
+ return s === first (loopdependencies (op)) ? 1 : 0
7
+ elseif isload (op)
8
+ return 100
9
+ end
10
+ newapp = 0
11
+ for opp ∈ parents (op)
12
+ newapp += indexappearences (opp, s)
13
+ end
14
+ factor = instruction (op). instr ∈ (:+ , :vadd , :add_fast , :evadd ) ? 1 : 10
15
+ newapp * factor
16
+ end
17
+ function findparent (ls:: LoopSet , s:: Symbol )# opdict isn't filled when reconstructing
18
+ id = findfirst (op -> name (op) === s, operations (ls))
19
+ id === nothing && throw (" $s not found" )
20
+ operations (ls)[id]
21
+ end
22
+ function unitstride (ls:: LoopSet , op:: Operation , s:: Symbol )
23
+ inds = getindices (op)
24
+ li = op. ref. loopedindex
25
+ # The first index is allowed to be indexed by `s`
26
+ fi = first (inds)
27
+ if fi === Symbol (" ##DISCONTIGUOUSSUBARRAY##" )
28
+ return false
29
+ elseif ! first (li)
30
+ # We must check if this
31
+ parent = findparent (ls, fi)
32
+ indexappearences (parent, s) > 1 && return false
33
+ end
34
+ for i ∈ 2 : length (inds)
35
+ if li[i]
36
+ s === inds[i] && return false
37
+ else
38
+ parent = findparent (ls, inds[i])
39
+ s ∈ loopdependencies (parent) && return false
40
+ end
41
+ end
42
+ true
43
+ end
5
44
6
45
function register_pressure (op:: Operation )
7
46
if isconstant (op)
@@ -10,7 +49,7 @@ function register_pressure(op::Operation)
10
49
instruction_cost (instruction (op)). register_pressure
11
50
end
12
51
end
13
- function cost (op:: Operation , unrolled:: Symbol , Wshift:: Int , size_T:: Int = op. elementbytes)
52
+ function cost (ls :: LoopSet , op:: Operation , unrolled:: Symbol , Wshift:: Int , size_T:: Int = op. elementbytes)
14
53
isconstant (op) && return 0.0 , 0 , 1
15
54
# Wshift == dependson(op, unrolled) ? Wshift : 0
16
55
# c = first(cost(instruction(op), Wshift, size_T))::Int
@@ -27,7 +66,7 @@ function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int = op.ele
27
66
# either vbroadcast/reductionstore, vmov(a/u)pd, or gather/scatter
28
67
# @show instr, unrolled, loopdependencies(op), unitstride(op, unrolled)
29
68
if opisunrolled
30
- if ! unitstride (op, unrolled)# || !isdense(op) # need gather/scatter
69
+ if ! unitstride (ls, op, unrolled)# || !isdense(op) # need gather/scatter
31
70
r = (1 << Wshift)
32
71
srt *= r
33
72
sl *= r
@@ -93,7 +132,7 @@ function evaluate_cost_unroll(
93
132
hasintersection (rd, nested_loop_syms[1 : end - length (rd)]) && return Inf
94
133
included_vars[id] = true
95
134
# @show op first(cost(op, vectorized, Wshift, size_T)), iter
96
- total_cost += iter * first (cost (op, vectorized, Wshift, size_T))
135
+ total_cost += iter * first (cost (ls, op, vectorized, Wshift, size_T))
97
136
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
98
137
end
99
138
end
@@ -102,18 +141,18 @@ end
102
141
103
142
# only covers vectorized ops; everything else considered lifted?
104
143
function depchain_cost! (
105
- skip:: Vector{Bool} , op:: Operation , vectorized:: Symbol , Wshift:: Int , size_T:: Int , rt:: Float64 = 0.0 , sl:: Int = 0
144
+ ls :: LoopSet , skip:: Vector{Bool} , op:: Operation , vectorized:: Symbol , Wshift:: Int , size_T:: Int , rt:: Float64 = 0.0 , sl:: Int = 0
106
145
)
107
146
skip[identifier (op)] = true
108
147
# depth first search
109
148
for opp ∈ parents (op)
110
149
skip[identifier (opp)] && continue
111
- rt, sl = depchain_cost! (skip, opp, vectorized, Wshift, size_T, rt, sl)
150
+ rt, sl = depchain_cost! (ls, skip, opp, vectorized, Wshift, size_T, rt, sl)
112
151
end
113
152
# Basically assuming memory and compute don't conflict, but everything else does
114
153
# Ie, ignoring the fact that integer and floating point operations likely don't either
115
154
if iscompute (op)
116
- rtᵢ, slᵢ = cost (op, vectorized, Wshift, size_T)
155
+ rtᵢ, slᵢ = cost (ls, op, vectorized, Wshift, size_T)
117
156
rt += rtᵢ; sl += slᵢ
118
157
end
119
158
rt, sl
@@ -139,9 +178,9 @@ function unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
139
178
for op ∈ operations (ls)
140
179
dependson (op, innermost) || continue
141
180
if iscompute (op)
142
- compute_rt += first (cost (op, vectorized, Wshift, size_T))
181
+ compute_rt += first (cost (ls, op, vectorized, Wshift, size_T))
143
182
elseif isload (op)
144
- load_rt += first (cost (op, vectorized, Wshift, size_T))
183
+ load_rt += first (cost (ls, op, vectorized, Wshift, size_T))
145
184
end
146
185
end
147
186
# heuristic guess
@@ -181,13 +220,13 @@ function determine_unroll_factor(
181
220
for op ∈ operations (ls)
182
221
dependson (op, unrolled) || continue
183
222
if isreduction (op)
184
- rt, sl = depchain_cost! (visited_nodes, op, vectorized, Wshift, size_T)
223
+ rt, sl = depchain_cost! (ls, visited_nodes, op, vectorized, Wshift, size_T)
185
224
latency = max (sl, latency)
186
225
compute_recip_throughput += rt
187
226
elseif isload (op)
188
- load_recip_throughput += first (cost (op, vectorized, Wshift, size_T))
227
+ load_recip_throughput += first (cost (ls, op, vectorized, Wshift, size_T))
189
228
elseif isstore (op)
190
- store_recip_throughput += first (cost (op, vectorized, Wshift, size_T))
229
+ store_recip_throughput += first (cost (ls, op, vectorized, Wshift, size_T))
191
230
end
192
231
end
193
232
recip_throughput = max (
@@ -424,7 +463,7 @@ function evaluate_cost_tile(
424
463
opisininnerloop = descendentsininnerloop[id]
425
464
isunrolled = unrolledtiled[1 ,id]
426
465
istiled = unrolledtiled[2 ,id]
427
- rt, lat, rp = cost (op, vectorized, Wshift, size_T)
466
+ rt, lat, rp = cost (ls, op, vectorized, Wshift, size_T)
428
467
rp = opisininnerloop ? rp : 0 # we only care about register pressure within the inner most loop
429
468
rt *= iters[id]
430
469
if isunrolled && istiled # no cost decrease; cost must be repeated
0 commit comments