Skip to content

Commit 2ae8a78

Browse files
authored
Merge pull request #65 from timholy/teh/loops
Fix `for i in iter`
2 parents 977528d + 73e2ed5 commit 2ae8a78

11 files changed

+216
-62
lines changed

Manifest.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
3131
deps = ["Base64"]
3232
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
3333

34+
[[OffsetArrays]]
35+
git-tree-sha1 = "707e34562700b81e8aa13548eb6b23b18112e49b"
36+
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
37+
version = "1.0.2"
38+
3439
[[OrderedCollections]]
3540
deps = ["Random", "Serialization", "Test"]
3641
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
@@ -49,9 +54,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4954

5055
[[SIMDPirates]]
5156
deps = ["VectorizationBase"]
52-
git-tree-sha1 = "34dff4f4715f871e71b38f31397d96e62621f14d"
57+
git-tree-sha1 = "f91198b7ef74b04028f98e0eed7c556b93538a2e"
5358
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
54-
version = "0.6.5"
59+
version = "0.6.6"
5560

5661
[[SLEEFPirates]]
5762
deps = ["Libdl", "SIMDPirates", "VectorizationBase"]
@@ -71,6 +76,6 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7176

7277
[[VectorizationBase]]
7378
deps = ["CpuId", "LinearAlgebra"]
74-
git-tree-sha1 = "006d7b7f276db8d728f8bfd70ebf2efd132f9548"
79+
git-tree-sha1 = "8abb5697fb64cadccd1bba444c955942d3181e5c"
7580
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
76-
version = "0.7.0"
81+
version = "0.7.1"

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.6.20"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
89
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
910
SIMDPirates = "21efa798-c60a-11e8-04d3-e1a92915a26a"
1011
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"

src/LoopVectorization.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ using VectorizationBase, SIMDPirates, SLEEFPirates, Parameters
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
55
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub, _MM,
66
maybestaticlength, maybestaticsize, staticm1, subsetview, vzero, stridedpointer_for_broadcast,
7-
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange,
7+
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
88
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct
99
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod,
1010
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
1111
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, #prefetch,
1212
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone
1313
using Base.Broadcast: Broadcasted, DefaultArrayStyle
1414
using LinearAlgebra: Adjoint, Transpose
15+
using Base.Meta: isexpr
1516

1617
const SUPPORTED_TYPES = Union{Float16,Float32,Float64,Integer}
1718

@@ -21,6 +22,7 @@ export LowDimArray, stridedpointer, vectorizable,
2122
vfilter, vfilter!
2223

2324

25+
include("vectorizationbase_extensions.jl")
2426
include("map.jl")
2527
include("filter.jl")
2628
include("costs.jl")

src/add_loads.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,3 @@ function add_loopvalue!(ls::LoopSet, arg::Symbol, elementbytes::Int)
7070
loopsymop
7171
end
7272

73-
74-
struct LoopValue end
75-
@inline VectorizationBase.stridedpointer(::LoopValue) = LoopValue()
76-
@inline VectorizationBase.vload(::LoopValue, i::Tuple{_MM{W}}) where {W} = _MM{W}(@inbounds(i[1].i) + 1)
77-
# @inline VectorizationBase.vload(::LoopValue, i::Tuple{_MM{W}}, ::Unsigned) where {W} = _MM{W}(@inbounds(i[1].i) + 1)
78-
@inline VectorizationBase.vload(::LoopValue, i::Tuple{_MM{W}}, ::Mask) where {W} = _MM{W}(@inbounds(i[1].i) + 1)
79-
@inline VectorizationBase.vload(::LoopValue, i::Integer) = i + one(i)
80-
@inline VectorizationBase.vload(::LoopValue, i::Tuple{I}) where {I<:Integer} = @inbounds(i[1]) + one(I)
81-
@inline Base.eltype(::LoopValue) = Int8
82-

src/graphs.jl

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
# For passing options like array types and mask
3535
# struct LoopSetOptions
36-
36+
3737
# end
3838

3939
struct Loop
@@ -70,7 +70,7 @@ function startloop(loop::Loop, isvectorized, W, itersymbol = loop.itersymbol)
7070
elseif startexact
7171
Expr(:(=), itersymbol, loop.starthint)
7272
else
73-
Expr(:(=), itersymbol, loop.startsym)
73+
Expr(:(=), itersymbol, Expr(:call, lv(:unwrap), loop.startsym))
7474
end
7575
end
7676
function vec_looprange(loop::Loop, isunrolled::Bool, W::Symbol, U::Int)
@@ -84,7 +84,7 @@ function vec_looprange(loop::Loop, isunrolled::Bool, W::Symbol, U::Int)
8484
else
8585
Expr(:call, :<, loop.itersymbol, Expr(:call, :-, loop.stopsym, incr))
8686
end
87-
end
87+
end
8888
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
8989
incr -= 1#one(Int32)
9090
if iszero(incr)
@@ -369,47 +369,59 @@ This function creates a loop, while switching from 1 to 0 based indices
369369
"""
370370
function register_single_loop!(ls::LoopSet, looprange::Expr)
371371
itersym = (looprange.args[1])::Symbol
372-
r = (looprange.args[2])::Expr
373-
@assert r.head === :call
374-
f = first(r.args)
375-
loop::Loop = if f === :(:)
376-
lower = r.args[2]
377-
upper = r.args[3]
378-
lii::Bool = lower isa Integer
379-
liiv::Int = lii ? (convert(Int, lower)-1) : 0
380-
uii::Bool = upper isa Integer
381-
if lii & uii # both are integers
382-
Loop(itersym, liiv, convert(Int, upper))
383-
elseif lii # only lower bound is an integer
384-
if upper isa Symbol
385-
Loop(itersym, liiv, upper)
386-
elseif upper isa Expr
387-
Loop(itersym, liiv, add_loop_bound!(ls, itersym, upper, true))
388-
else
389-
Loop(itersym, liiv, add_loop_bound!(ls, itersym, upper, true))
372+
r = looprange.args[2]
373+
if isexpr(r, :call)
374+
f = first(r.args)
375+
loop::Loop = if f === :(:)
376+
lower = r.args[2]
377+
upper = r.args[3]
378+
lii::Bool = lower isa Integer
379+
liiv::Int = lii ? (convert(Int, lower)-1) : 0
380+
uii::Bool = upper isa Integer
381+
if lii & uii # both are integers
382+
Loop(itersym, liiv, convert(Int, upper))
383+
elseif lii # only lower bound is an integer
384+
if upper isa Symbol
385+
Loop(itersym, liiv, upper)
386+
elseif upper isa Expr
387+
Loop(itersym, liiv, add_loop_bound!(ls, itersym, upper, true))
388+
else
389+
Loop(itersym, liiv, add_loop_bound!(ls, itersym, upper, true))
390+
end
391+
elseif uii # only upper bound is an integer
392+
uiiv = convert(Int, upper)
393+
Loop(itersym, add_loop_bound!(ls, itersym, lower, false), uiiv)
394+
else # neither are integers
395+
L = add_loop_bound!(ls, itersym, lower, false)
396+
U = add_loop_bound!(ls, itersym, upper, true)
397+
Loop(itersym, L, U)
390398
end
391-
elseif uii # only upper bound is an integer
392-
uiiv = convert(Int, upper)
393-
Loop(itersym, add_loop_bound!(ls, itersym, lower, false), uiiv)
394-
else # neither are integers
395-
L = add_loop_bound!(ls, itersym, lower, false)
396-
U = add_loop_bound!(ls, itersym, upper, true)
399+
elseif f === :eachindex
400+
N = gensym(Symbol(:loopeachindex, itersym))
401+
pushpreamble!(ls, Expr(:(=), N, Expr(:call, lv(:maybestaticrange), r)))
402+
L = add_loop_bound!(ls, itersym, Expr(:call, :first, N), false)
403+
U = add_loop_bound!(ls, itersym, Expr(:call, :last, N), true)
397404
Loop(itersym, L, U)
398-
end
399-
elseif f === :eachindex
400-
N = gensym(Symbol(:loop, itersym))
401-
pushpreamble!(ls, Expr(:(=), N, Expr(:call, lv(:maybestaticlength), r.args[2])))
402-
Loop(itersym, 0, N)
403-
elseif f === :OneTo || f == Expr(:(.), :Base, QuoteNode(:OneTo))
404-
otN = r.args[2]
405-
if otN isa Integer
406-
Loop(itersym, 0, otN)
405+
elseif f === :OneTo || f == Expr(:(.), :Base, QuoteNode(:OneTo))
406+
otN = r.args[2]
407+
if otN isa Integer
408+
Loop(itersym, 0, otN)
409+
else
410+
otN isa Expr && maybestatic!(otN)
411+
N = gensym(Symbol(:loop, itersym))
412+
pushpreamble!(ls, Expr(:(=), N, otN))
413+
Loop(itersym, 0, N)
414+
end
407415
else
408-
otN isa Expr && maybestatic!(otN)
409-
N = gensym(Symbol(:loop, itersym))
410-
pushpreamble!(ls, Expr(:(=), N, otN))
411-
Loop(itersym, 0, N)
416+
throw("Unrecognized loop range type: $r.")
412417
end
418+
elseif isa(r, Symbol)
419+
# Treat similar to `eachindex`
420+
N = gensym(Symbol(:loop, itersym))
421+
pushpreamble!(ls, Expr(:(=), N, Expr(:call, lv(:maybestaticrange), r)))
422+
L = add_loop_bound!(ls, itersym, Expr(:call, :first, N), false)
423+
U = add_loop_bound!(ls, itersym, Expr(:call, :last, N), true)
424+
loop = Loop(itersym, L, U)
413425
else
414426
throw("Unrecognized loop range type: $r.")
415427
end
@@ -546,7 +558,3 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
546558
throw("Don't know how to handle expression:\n$ex")
547559
end
548560
end
549-
550-
551-
552-

src/reconstruct_loopset.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ function Loop(ls::LoopSet, l::Int, ::Type{StaticLowerUnitRange{L}}) where {L}
1414
pushpreamble!(ls, Expr(:(=), stop, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:U)))))
1515
Loop(gensym(:n), L, L + 1024, Symbol(""), stop, true, false)::Loop
1616
end
17+
# Is there any likely way to generate such a range?
18+
# function Loop(ls::LoopSet, l::Int, ::Type{StaticLengthUnitRange{N}}) where {N}
19+
# start = gensym(:loopstart); stop = gensym(:loopstop)
20+
# pushpreamble!(ls, Expr(:(=), start, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:(.), Expr(:ref, :lb, l), QuoteNode(:L)))))
21+
# pushpreamble!(ls, Expr(:(=), stop, Expr(:call, :(+), start, N - 1)))
22+
# Loop(gensym(:n), 0, N, start, stop, false, false)::Loop
23+
# end
1724
function Loop(ls, l, ::Type{StaticUnitRange{L,U}}) where {L,U}
1825
Loop(gensym(:n), L, U, Symbol(""), Symbol(""), true, true)::Loop
1926
end
@@ -63,14 +70,18 @@ extract_varg(i) = Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__
6370
pushvarg!(ls::LoopSet, ar::ArrayReferenceMeta, i) = pushpreamble!(ls, Expr(:(=), vptr(ar), extract_varg(i)))
6471
function pushvarg′!(ls::LoopSet, ar::ArrayReferenceMeta, i)
6572
reverse!(ar.loopedindex); reverse!(getindices(ar)) # reverse the listed indices here, and transpose it to make it column major
66-
pushpreamble!(ls, Expr(:(=), vptr(ar), Expr(:call, lv(:Transpose), extract_varg(i))))
73+
pushpreamble!(ls, Expr(:(=), vptr(ar), Expr(:call, lv(:transpose), extract_varg(i))))
6774
end
6875
function add_mref!(ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{PackedStridedPointer{T, N}}) where {T, N}
6976
pushvarg!(ls, ar, i)
7077
end
7178
function add_mref!(ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{RowMajorStridedPointer{T, N}}) where {T, N}
7279
pushvarg′!(ls, ar, i)
7380
end
81+
function add_mref!(ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{OffsetStridedPointer{T,N,P}}) where {T,N,P}
82+
add_mref!(ls, ar, i, P)
83+
end
84+
7485
function add_mref!(
7586
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}
7687
) where {T, X <: Tuple, S <: VectorizationBase.AbstractStaticStridedPointer{T,X}}

src/vectorizationbase_extensions.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
struct LoopValue end
3+
@inline VectorizationBase.stridedpointer(::LoopValue) = LoopValue()
4+
@inline VectorizationBase.vload(::LoopValue, i::Tuple{_MM{W}}) where {W} = _MM{W}(@inbounds(i[1].i) + 1)
5+
# @inline VectorizationBase.vload(::LoopValue, i::Tuple{_MM{W}}, ::Unsigned) where {W} = _MM{W}(@inbounds(i[1].i) + 1)
6+
@inline VectorizationBase.vload(::LoopValue, i::Tuple{_MM{W}}, ::Mask) where {W} = _MM{W}(@inbounds(i[1].i) + 1)
7+
@inline VectorizationBase.vload(::LoopValue, i::Integer) = i + one(i)
8+
@inline VectorizationBase.vload(::LoopValue, i::Tuple{I}) where {I<:Integer} = @inbounds(i[1]) + one(I)
9+
@inline Base.eltype(::LoopValue) = Int8
10+
11+
import OffsetArrays
12+
13+
# If ndim(::OffsetArray) == 1, we can convert to a regular strided pointer and offset.
14+
@inline VectorizationBase.stridedpointer(a::OffsetArrays.OffsetArray{<:Any,1}) = gesp(stridedpointer(parent(a)), (-@inbounds(a.offsets[1]),))
15+
16+
struct OffsetStridedPointer{T, N, P <: VectorizationBase.AbstractStridedPointer{T}} <: VectorizationBase.AbstractStridedPointer{T}
17+
ptr::P
18+
offsets::NTuple{N,Int}
19+
end
20+
# if ndim(A::OffsetArray) ≥ 2, then eachindex(A) isa Base.OneTo, index starting at 1.
21+
# but multiple indexing is calculated using offsets, so we need a special type to express this.
22+
@inline function VectorizationBase.stridedpointer(A::OffsetArrays.OffsetArray)
23+
OffsetStridedPointer(stridedpointer(parent(A)), A.offsets)
24+
end
25+
# Tuple of length == 1, use ind directly.
26+
# @inline VectorizationBase.offset(ptr::OffsetStridedPointer, ind::Tuple{I}) where {I} = VectorizationBase.offset(ptr.ptr, ind)
27+
# Tuple of length > 1, subtract offsets.
28+
# @inline VectorizationBase.offset(ptr::OffsetStridedPointer{<:Any,N}, ind::Tuple) where {N} = VectorizationBase.offset(ptr.ptr, ntuple(n -> ind[n] + ptr.offsets[n], Val{N}()))
29+
@inline VectorizationBase.offset(ptr::OffsetStridedPointer, ind::Tuple{I}) where {I} = ind
30+
# Tuple of length > 1, subtract offsets.
31+
@inline VectorizationBase.offset(ptr::OffsetStridedPointer{<:Any,N}, ind::Tuple) where {N} = ntuple(n -> ind[n] - ptr.offsets[n], Val{N}())
32+
@inline Base.similar(p::OffsetStridedPointer, ptr::Ptr) = OffsetStridedPointer(similar(p.ptr, ptr), p.offsets)
33+
34+
# If an OffsetArray is getting indexed by a (loop-)constant value, then this particular vptr object cannot also be eachindexed, so we can safely return a stridedpointer
35+
@inline function VectorizationBase.subsetview(ptr::OffsetStridedPointer{<:Any,N}, ::Val{I}, i) where {I,N}
36+
subsetview(gesp(ptr.ptr, ntuple(n -> 0 - @inbounds(ptr.offsets[n]), Val{N}())), Val{I}(), i)
37+
end
38+

test/dot.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using LoopVectorization, OffsetArrays
2+
using Test
3+
14
@testset "dot" begin
25
dotq = :(for i eachindex(a,b)
36
s += a[i]*b[i]
@@ -46,6 +49,14 @@
4649
end
4750
s
4851
end
52+
function myselfdotavx_range(a)
53+
s = zero(eltype(a))
54+
rng = axes(a, 1)
55+
@avx for i rng
56+
s += a[i]*a[i]
57+
end
58+
s
59+
end
4960
function myselfdot_avx(a)
5061
s = zero(eltype(a))
5162
@_avx for i eachindex(a)
@@ -167,7 +178,7 @@
167178
end
168179
4acc/length(x)
169180
end
170-
181+
171182
# @macroexpand @_avx for i = 1:length(a_re) - 1
172183
# c_re[i] = b_re[i] * a_re[i + 1] - b_im[i] * a_im[i + 1]
173184
# c_im[i] = b_re[i] * a_im[i + 1] + b_im[i] * a_re[i + 1]
@@ -179,9 +190,12 @@
179190
N = 127
180191
R = T <: Integer ? (T(-100):T(100)) : T
181192
a = rand(T, N); b = rand(R, N);
193+
ao = OffsetArray(a, -60:66); bo = OffsetArray(b, -60:66);
182194
s = mydot(a, b)
183195
@test mydotavx(a,b) s
184196
@test mydot_avx(a,b) s
197+
@test mydotavx(ao,bo) s
198+
@test mydot_avx(ao,bo) s
185199
@test dot_unroll2avx(a,b) s
186200
@test dot_unroll3avx(a,b) s
187201
@test dot_unroll2_avx(a,b) s
@@ -190,6 +204,7 @@
190204
@test dot_unroll3avx_inline(a,b) s
191205
s = myselfdot(a)
192206
@test myselfdotavx(a) s
207+
@test myselfdotavx_range(a) s
193208
@test myselfdot_avx(a) s
194209
@test myselfdotavx(a) s
195210

@@ -205,7 +220,7 @@
205220
b_re = rand(R, N); b_im = rand(R, N);
206221
ac = Complex.(a_re, a_im);
207222
bc = Complex.(b_re, b_im);
208-
223+
209224
@test mydot(ac, bc) complex_dot_soa(a_re, a_im, b_re, b_im)
210225

211226
c_re1 = similar(a_re); c_im1 = similar(a_im);

test/gemv.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using LoopVectorization
2+
using Test
3+
14
@testset "GEMV" begin
25
gemvq = :(for i eachindex(y)
36
yᵢ = 0.0
@@ -27,6 +30,16 @@
2730
y[i] = yᵢ
2831
end
2932
end
33+
function mygemvavx_range!(y, A, x)
34+
rng1, rng2 = axes(A)
35+
@avx for i rng1
36+
yᵢ = zero(eltype(y))
37+
for j rng2
38+
yᵢ += A[i,j] * x[j]
39+
end
40+
y[i] = yᵢ
41+
end
42+
end
3043
q = :(for i eachindex(y)
3144
yᵢ = zero(eltype(y))
3245
for j eachindex(x)
@@ -150,6 +163,9 @@
150163
@test y1 y2
151164
fill!(y2, -999.9); mygemv_avx!(y2, A, x)
152165
@test y1 y2
166+
fill!(y2, -999.9)
167+
mygemvavx_range!(y2, A, x)
168+
@test y1 y2
153169

154170
B = rand(R, N, N);
155171
G1 = Matrix{TC}(undef, N, 1);

0 commit comments

Comments
 (0)