Skip to content

Commit 51a75c9

Browse files
committed
Use Mask type in place of unsigned integers to represent bitmasks and update to deps that added more support for them; fixes #60.
1 parent f93c523 commit 51a75c9

File tree

9 files changed

+145
-23
lines changed

9 files changed

+145
-23
lines changed

Manifest.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4949

5050
[[SIMDPirates]]
5151
deps = ["VectorizationBase"]
52-
git-tree-sha1 = "4b1e0b1442fb4af5e6b93b9c7fdeacf287d2653b"
52+
git-tree-sha1 = "839625f8699855a7d5ca96be25bc24d71c5c00ff"
5353
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
54-
version = "0.5.0"
54+
version = "0.6.0"
5555

5656
[[SLEEFPirates]]
5757
deps = ["Libdl", "SIMDPirates", "VectorizationBase"]
58-
git-tree-sha1 = "769fd039d0835e8e628d61e2f0c80822ba668497"
58+
git-tree-sha1 = "62368836fef70b461ac005ed0112315222eab5b5"
5959
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
60-
version = "0.3.9"
60+
version = "0.4.0"
6161

6262
[[Serialization]]
6363
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -71,6 +71,6 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7171

7272
[[VectorizationBase]]
7373
deps = ["CpuId", "LinearAlgebra"]
74-
git-tree-sha1 = "9f8caaa5d033f88e188f62a3dba0dab5f429447a"
74+
git-tree-sha1 = "9410db46eeb38d9fb108fae9758713cfafc4cb91"
7575
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
76-
version = "0.5.0"
76+
version = "0.6.1"

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.6.15"
4+
version = "0.6.16"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -12,9 +12,9 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1212

1313
[compat]
1414
Parameters = "0"
15-
SIMDPirates = "~0.5"
16-
SLEEFPirates = "~0.3.9"
17-
VectorizationBase = "~0.5"
15+
SIMDPirates = "~0.6"
16+
SLEEFPirates = "~0.4"
17+
VectorizationBase = "~0.6.1"
1818
julia = "1.1"
1919

2020
[extras]

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector
77
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange,
88
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct
99
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod,
10-
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vfmadd!, vfnmadd!,
10+
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
1111
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone
1212
using Base.Broadcast: Broadcasted, DefaultArrayStyle
1313
using LinearAlgebra: Adjoint, Transpose

src/add_compute.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,15 @@ function add_reduction_update_parent!(
137137
reductcombine = Symbol("")
138138
end
139139
combineddeps = copy(deps); mergesetv!(combineddeps, reduceddeps)
140-
directdependency && pushparent!(vparents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
140+
# directdependency && pushparent!(vparents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
141+
if directdependency
142+
if instr (:-, :vsub!, :vsub, :/, :vfdiv!, :vfidiv!)
143+
pushfirst!(vparents, reductinit)
144+
update_deps!(deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
145+
else
146+
push!(vparents, reductinit)
147+
end
148+
end
141149
update_reduction_status!(vparents, reduceddeps, name(reductinit))
142150
# this is the op added by add_compute
143151
op = Operation(length(operations(ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, vparents)

src/add_ifelse.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,25 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
88
# for now, just simple 1-liners
99
@assert length(RHS.args) == 3 "if statements without an else cannot be assigned to a variable."
1010
condition = first(RHS.args)
11-
condop = add_compute!(ls, gensym(:mask), condition, elementbytes, position, mpref)
11+
condop = if mpref === nothing
12+
add_operation!(ls, gensym(:mask), condition, elementbytes, position)
13+
else
14+
add_operation!(ls, gensym(:mask), condition, mpref, elementbytes, position)
15+
end
1216
iftrue = RHS.args[2]
13-
(iftrue isa Expr && iftrue.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
14-
trueop = add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
17+
trueop = if iftrue isa Expr
18+
(iftrue isa Expr && iftrue.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
19+
add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
20+
else
21+
getop(ls, iftrue, elementbytes)
22+
end
1523
iffalse = RHS.args[3]
16-
(iffalse isa Expr && iffalse.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
17-
falseop = add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
18-
24+
falseop = if iffalse isa Expr
25+
(iffalse isa Expr && iffalse.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
26+
add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
27+
else
28+
getop(ls, iffalse, elementbytes)
29+
end
1930
add_compute!(ls, LHS, :vifelse, [condop, trueop, falseop], elementbytes)
2031
end
2132

@@ -38,7 +49,7 @@ function add_andblock!(ls::LoopSet, condop::Operation, LHS, RHS, elementbytes::I
3849
add_andblock!(ls, condop, LHS, rhsop, elementbytes, position)
3950
end
4051
function add_andblock!(ls::LoopSet, condexpr::Expr, condeval::Expr, elementbytes::Int, position::Int)
41-
condop = add_compute!(ls, gensym(:mask), condexpr, elementbytes, position)
52+
condop = add_operation!(ls, gensym(:mask), condexpr, elementbytes, position)
4253
if condeval.head === :call
4354
@assert first(condeval.args) === :setindex!
4455
array, raw_indices = ref_from_setindex(condeval)
@@ -79,7 +90,7 @@ function add_orblock!(ls::LoopSet, condop::Operation, LHS, RHS, elementbytes::In
7990
add_orblock!(ls, condop, LHS, rhsop, elementbytes, position)
8091
end
8192
function add_orblock!(ls::LoopSet, condexpr::Expr, condeval::Expr, elementbytes::Int, position::Int)
82-
condop = add_compute!(ls, gensym(:mask), condexpr, elementbytes, position)
93+
condop = add_operation!(ls, gensym(:mask), condexpr, elementbytes, position)
8394
if condeval.head === :call
8495
@assert first(condeval.args) === :setindex!
8596
array, raw_indices = ref_from_setindex(condeval)

src/costs.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ const COST = Dict{Instruction,InstructionCost}(
110110
Instruction(:vsub) => InstructionCost(4,0.5),
111111
Instruction(:vadd!) => InstructionCost(4,0.5),
112112
Instruction(:vsub!) => InstructionCost(4,0.5),
113+
Instruction(:vmul!) => InstructionCost(4,0.5),
113114
Instruction(:vmul) => InstructionCost(4,0.5),
114115
Instruction(:vfdiv) => InstructionCost(13,4.0,-2.0),
116+
Instruction(:vfdiv!) => InstructionCost(13,4.0,-2.0),
115117
Instruction(:evadd) => InstructionCost(4,0.5),
116118
Instruction(:evsub) => InstructionCost(4,0.5),
117119
Instruction(:evmul) => InstructionCost(4,0.5),
@@ -152,6 +154,8 @@ const COST = Dict{Instruction,InstructionCost}(
152154
Instruction(:vfnmsub) => InstructionCost(4,0.5), # - and -* will fuse into this, so much of the time they're not twice as expensive
153155
Instruction(:vfmadd!) => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
154156
Instruction(:vfnmadd!) => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
157+
Instruction(:vfmsub!) => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
158+
Instruction(:vfnmsub!) => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
155159
Instruction(:vfmadd_fast) => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
156160
Instruction(:vfmsub_fast) => InstructionCost(4,0.5), # - and * will fuse into this, so much of the time they're not twice as expensive
157161
Instruction(:vfnmadd_fast) => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
@@ -212,7 +216,10 @@ const REDUCTION_CLASS = Dict{Symbol,Float64}(
212216
:* => MULTIPLICATIVE_IN_REDUCTIONS,
213217
:vadd => ADDITIVE_IN_REDUCTIONS,
214218
:vsub => ADDITIVE_IN_REDUCTIONS,
219+
:vadd! => ADDITIVE_IN_REDUCTIONS,
220+
:vsub! => ADDITIVE_IN_REDUCTIONS,
215221
:vmul => MULTIPLICATIVE_IN_REDUCTIONS,
222+
:vmul! => MULTIPLICATIVE_IN_REDUCTIONS,
216223
:evadd => ADDITIVE_IN_REDUCTIONS,
217224
:evsub => ADDITIVE_IN_REDUCTIONS,
218225
:evmul => MULTIPLICATIVE_IN_REDUCTIONS,
@@ -228,6 +235,8 @@ const REDUCTION_CLASS = Dict{Symbol,Float64}(
228235
:vfnmsub => ADDITIVE_IN_REDUCTIONS,
229236
:vfmadd! => ADDITIVE_IN_REDUCTIONS,
230237
:vfnmadd! => ADDITIVE_IN_REDUCTIONS,
238+
:vfmsub! => ADDITIVE_IN_REDUCTIONS,
239+
:vfnmsub! => ADDITIVE_IN_REDUCTIONS,
231240
:vfmadd_fast => ADDITIVE_IN_REDUCTIONS,
232241
:vfmsub_fast => ADDITIVE_IN_REDUCTIONS,
233242
:vfnmadd_fast => ADDITIVE_IN_REDUCTIONS,
@@ -283,8 +292,11 @@ const FUNCTIONSYMBOLS = Dict{Type{<:Function},Instruction}(
283292
typeof(Base.FastMath.sub_fast) => :(-),
284293
typeof(*) => :(*),
285294
typeof(SIMDPirates.vmul) => :(*),
295+
typeof(SIMDPirates.vmul!) => :(*),
286296
typeof(Base.FastMath.mul_fast) => :(*),
287297
typeof(/) => :(/),
298+
typeof(SIMDPirates.vfdiv) => :(/),
299+
typeof(SIMDPirates.vfdiv!) => :(/),
288300
typeof(SIMDPirates.vdiv) => :(/),
289301
typeof(Base.FastMath.div_fast) => :(/),
290302
typeof(==) => :(==),
@@ -306,6 +318,8 @@ const FUNCTIONSYMBOLS = Dict{Type{<:Function},Instruction}(
306318
typeof(SIMDPirates.vfnmsub) => :vfnmsub,
307319
typeof(SIMDPirates.vfmadd!) => :vfmadd!,
308320
typeof(SIMDPirates.vfnmadd!) => :vfnmadd!,
321+
typeof(SIMDPirates.vfmsub!) => :vfmsub!,
322+
typeof(SIMDPirates.vfnmsub!) => :vfnmsub!,
309323
typeof(SIMDPirates.vfmadd_fast) => :vfmadd_fast,
310324
typeof(SIMDPirates.vfmsub_fast) => :vfmsub_fast,
311325
typeof(SIMDPirates.vfnmadd_fast) => :vfnmadd_fast,

src/graphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ Base.length(ls::LoopSet, s::Symbol) = length(getloop(ls, s))
290290
isstaticloop(ls::LoopSet, s::Symbol) = isstaticloop(getloop(ls,s))
291291
looprangehint(ls::LoopSet, s::Symbol) = length(getloop(ls, s))
292292
looprangesym(ls::LoopSet, s::Symbol) = getloop(ls, s).rangesym
293+
getop(ls::LoopSet, var::Number, elementbytes) = add_constant!(ls, var, elementbytes)
293294
function getop(ls::LoopSet, var::Symbol, elementbytes::Int)
294295
get!(ls.opdict, var) do
295296
add_constant!(ls, var, elementbytes)

src/precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function _precompile_()
5858
precompile(Tuple{typeof(Base.Broadcast.broadcasted),Function,Array{Int64,3},LowDimArray{(false, true, true),Int64,3,Array{Int64,3}}})
5959
precompile(Tuple{typeof(Base.Broadcast.broadcasted),Function,Array{Int64,3},LowDimArray{(true, false, true),Int64,3,Array{Int64,3}}})
6060
precompile(Tuple{typeof(Base.Broadcast.broadcasted),Function,Array{Int64,3},LowDimArray{(true, true, false),Int64,3,Array{Int64,3}}})
61-
precompile(Tuple{typeof(Base.Broadcast.broadcasted),typeof(*ˡ),Array{Float64,2},Array{Float64,1}})
61+
precompile(Tuple{typeof(Base.Broadcast.broadcasted),typeof(*ˡ),Array{Int32,2},Array{Int32,1}})
6262
precompile(Tuple{typeof(Base.Broadcast.broadcasted),typeof(*ˡ),Array{Int64,2},Array{Int64,1}})
6363
precompile(Tuple{typeof(LoopVectorization._avx_loopset),Core.SimpleVector,Core.SimpleVector,Core.SimpleVector,Core.SimpleVector,NTuple{4,DataType}})
6464
precompile(Tuple{typeof(LoopVectorization._avx_loopset),Core.SimpleVector,Core.SimpleVector,Core.SimpleVector,Core.SimpleVector,NTuple{5,DataType}})

test/ifelsemasks.jl

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,68 @@
11
@testset "ifelse (masks)" begin
2+
3+
function promote_bool_store!(z, x, y)
4+
for i eachindex(x)
5+
z[i] = (x[i]*x[i] + y[i]*y[i]) < 1
6+
end
7+
z
8+
end
9+
function promote_bool_storeavx!(z, x, y)
10+
@avx for i eachindex(x)
11+
z[i] = (x[i]*x[i] + y[i]*y[i]) < 1
12+
end
13+
z
14+
end
15+
function promote_bool_store_avx!(z, x, y)
16+
@_avx for i eachindex(x)
17+
z[i] = (x[i]*x[i] + y[i]*y[i]) < 1
18+
end
19+
z
20+
end
21+
function promote_bool_storeavx2!(z, x, y)
22+
@avx for i eachindex(x)
23+
z[i] = (x[i]*x[i] + y[i]*y[i]) < 1 ? 1 : 0
24+
end
25+
z
26+
end
27+
function promote_bool_store_avx2!(z, x, y)
28+
@_avx for i eachindex(x)
29+
z[i] = (x[i]*x[i] + y[i]*y[i]) < 1 ? 1 : 0
30+
end
31+
z
32+
end
33+
34+
function Bernoulli_logit(y::BitVector, α::AbstractVector{T}) where {T}
35+
t = zero(promote_type(Float32,T))
36+
@inbounds for i eachindex(α)
37+
invOmP = 1 + exp(α[i])
38+
nlogOmP = log(invOmP)
39+
nlogP = nlogOmP - α[i]
40+
t -= y[i] ? nlogP : nlogOmP
41+
end
42+
t
43+
end
44+
function Bernoulli_logitavx(y::BitVector, α::AbstractVector{T}) where {T}
45+
t = zero(promote_type(Float32,T))
46+
@avx for i eachindex(α)
47+
invOmP = 1 + exp(α[i])
48+
nlogOmP = log(invOmP)
49+
nlogP = nlogOmP - α[i]
50+
t -= y[i] ? nlogP : nlogOmP
51+
end
52+
t
53+
end
54+
function Bernoulli_logit_avx(y::BitVector, α::AbstractVector{T}) where {T}
55+
t = zero(promote_type(Float32,T))
56+
@_avx for i eachindex(α)
57+
invOmP = 1 + exp(α[i])
58+
nlogOmP = log(invOmP)
59+
nlogP = nlogOmP - α[i]
60+
t -= y[i] ? nlogP : nlogOmP
61+
end
62+
t
63+
end
64+
65+
266
function addormul!(c, a, b)
367
for i eachindex(c,a,b)
468
c[i] = a[i] > b[i] ? a[i] + b[i] : a[i] * b[i]
@@ -227,8 +291,19 @@
227291
a = rand(T, N); b = rand(T, N);
228292
end;
229293
c1 = similar(a); c2 = similar(a);
230-
addormul!(c1, a, b)
231-
addormul_avx!(c2, a, b)
294+
295+
promote_bool_store!(c1, a, b)
296+
promote_bool_storeavx!(c2, a, b)
297+
@test c1 == c2
298+
fill!(c2, -999999999); promote_bool_store_avx!(c2, a, b)
299+
@test c1 == c2
300+
fill!(c2, -999999999); promote_bool_storeavx2!(c2, a, b)
301+
@test c1 == c2
302+
fill!(c2, -999999999); promote_bool_store_avx2!(c2, a, b)
303+
@test c1 == c2
304+
305+
fill!(c2, -999999999); addormul!(c1, a, b)
306+
fill!(c2, -999999999); addormul_avx!(c2, a, b)
232307
@test c1 c2
233308
fill!(c2, -999999999); addormulavx!(c2, a, b)
234309
@test c1 c2
@@ -296,4 +371,17 @@
296371
@test C1 C2
297372
@test C1 C3
298373
end
374+
375+
376+
a = rand(-10:10, 43);
377+
bit = a .> 0.5;
378+
t = Bernoulli_logit(bit, a);
379+
@test t Bernoulli_logitavx(bit, a)
380+
@test t Bernoulli_logit_avx(bit, a)
381+
a = rand(43)
382+
bit = a .> 0.5;
383+
t = Bernoulli_logit(bit, a);
384+
@test t Bernoulli_logitavx(bit, a)
385+
@test t Bernoulli_logit_avx(bit, a)
386+
299387
end

0 commit comments

Comments
 (0)