@@ -56,7 +56,14 @@ function lower_compute!(
56
56
# making BitArrays inefficient.
57
57
# parentsyms = [opp.variable for opp ∈ parents(op)]
58
58
Uiter = opunrolled ? U - 1 : 0
59
- maskreduct = mask != = nothing && isreduction (op) && vectorized ∈ reduceddependencies (op) # any(opp -> opp.variable === var, parents_op)
59
+ isreduct = isreduction (op)
60
+ if ! isnothing (suffix) && isreduct
61
+ instrfid = findfirst (isequal (instr. instr), (:vfmadd_fast , :vfnmadd_fast , :vfmsub_fast , :vfnmsub_fast ))
62
+ if instrfid != = nothing
63
+ instr = Instruction ((:vfmadd231 , :vfnmadd231 , :vfmsub231 , :vfnmsub231 )[instrfid])
64
+ end
65
+ end
66
+ maskreduct = mask != = nothing && isreduct && vectorized ∈ reduceddependencies (op) # any(opp -> opp.variable === var, parents_op)
60
67
# if a parent is not unrolled, the compiler should handle broadcasting CSE.
61
68
# because unrolled/tiled parents result in an unrolled/tiled dependency,
62
69
# we handle both the tiled and untiled case here.
@@ -90,10 +97,15 @@ function lower_compute!(
90
97
push! (instrcall. args, parent)
91
98
end
92
99
if maskreduct && (u == Uiter || unrolled != = vectorized) # only mask last
93
- push! (q. args, Expr (:(= ), varsym, Expr (:call , lv (:vifelse ), mask, instrcall, varsym)))
94
- else
95
- push! (q. args, Expr (:(= ), varsym, instrcall))
100
+ if last (instrcall. args) == varsym
101
+ pushfirst! (instrcall. args, lv (:vifelse ))
102
+ insert! (instrcall. args, 3 , mask)
103
+ else
104
+ push! (q. args, Expr (:(= ), varsym, Expr (:call , lv (:vifelse ), mask, instrcall, varsym)))
105
+ continue
106
+ end
96
107
end
108
+ push! (q. args, Expr (:(= ), varsym, instrcall))
97
109
end
98
110
end
99
111
0 commit comments