Skip to content

Commit e047ea4

Browse files
committed
Add vmapnt(t)(!) support for non-aligned destination arrays.
1 parent 7f9123f commit e047ea4

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

src/map.jl

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,32 @@ end
2525
vmap_quote(N, T)
2626
end
2727

28-
function vmapnt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
28+
function alignstores!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
29+
N = length(y)
2930
ptry = pointer(y)
30-
@assert reinterpret(UInt, ptry) & (VectorizationBase.REGISTER_SIZE - 1) == 0
31-
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
3231
ptrargs = pointer.(args)
32+
W = VectorizationBase.pick_vector_width(T)
3333
V = VectorizationBase.pick_vector_width_val(T)
34-
N = length(y)
34+
@assert iszero(reinterpret(UInt, ptry) & (sizeof(T) - 1)) "The destination vector (`dest`) must be aligned at least to `sizeof(eltype(dest))`."
35+
alignment = reinterpret(UInt, ptry) & (VectorizationBase.REGISTER_SIZE - 1)
36+
if alignment > 0
37+
i = reinterpret(Int, W - (alignment >>> VectorizationBase.intlog2(sizeof(T))))
38+
m = mask(T, i)
39+
if N < i
40+
m &= mask(T, N & (W - 1))
41+
end
42+
vstore!(ptry, extract_data(f(vload.(V, ptrargs, m)...)), m)
43+
gep(ptry, i), gep.(ptrargs, i), N - i
44+
else
45+
ptry, ptrargs, N
46+
end
47+
end
48+
49+
function vmapnt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
50+
ptry, ptrargs, N = alignstores!(f, y, args...)
3551
i = 0
52+
W = VectorizationBase.pick_vector_width(T)
53+
V = VectorizationBase.pick_vector_width_val(T)
3654
while i < N - ((W << 2) - 1)
3755
vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
3856
vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
@@ -49,12 +67,9 @@ function vmapnt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A
4967
y
5068
end
5169
function vmapntt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
52-
ptry = pointer(y)
53-
@assert reinterpret(UInt, ptry) & (VectorizationBase.REGISTER_SIZE - 1) == 0
70+
ptry, ptrargs, N = alignstores!(f, y, args...)
5471
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
55-
ptrargs = pointer.(args)
5672
V = VectorizationBase.pick_vector_width_val(T)
57-
N = length(y)
5873
Wsh = Wshift + 2
5974
Niter = N >>> Wsh
6075
Base.Threads.@threads for j ∈ 0:Niter-1

test/map.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testset "map" begin
22
@inline foo(x, y) = exp(x) - sin(y)
33
N = 3781
4+
45
for T ∈ (Float32,Float64)
56
@show T, @__LINE__
67
a = rand(T, N); b = rand(T, N);
@@ -11,8 +12,10 @@
1112
@test c1 ≈ c2
1213
c2 = vmapntt(foo, a, b);
1314
@test c1 ≈ c2
14-
@test_throws AssertionError @views vmapnt!(c2[2:end], a[2:end], b[2:end])
15-
@test_throws AssertionError @views vmapntt!(c2[2:end], a[2:end], b[2:end])
15+
fill!(c2, NaN); @views vmapnt!(foo, c2[2:end], a[2:end], b[2:end]);
16+
@test @views c1[2:end] ≈ c2[2:end]
17+
fill!(c2, NaN); @views vmapntt!(foo, c2[2:end], a[2:end], b[2:end]);
18+
@test @views c1[2:end] ≈ c2[2:end]
1619

1720
c = rand(T,100); x = rand(T,10^4); y1 = similar(x); y2 = similar(x);
1821
map!(xᵢ -> clenshaw(xᵢ, c), y1, x)

0 commit comments

Comments
 (0)