Skip to content

Commit ffa6f0a

Browse files
authored
Call xxxI instead of xxx to support strided layout. (2nd take) (#58)
* add `inc` support. add strided test Update setup.jl * add newly added IVM funs * add `sincos` * test tune. * update readme and example * bump * use dense api if possible. (seems faster.)
1 parent 2852709 commit ffa6f0a

File tree

6 files changed

+170
-40
lines changed

6 files changed

+170
-40
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "IntelVectorMath"
22
uuid = "c8ce9da6-5d36-5c03-b118-5a70151be7bc"
3-
version = "0.4.2"
3+
version = "0.4.3"
44

55
[deps]
66
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ julia> b = similar(a);
5555

5656
julia> @btime IVM.sin!(b, a); # in-place version
5757
20.008 μs (0 allocations: 0 bytes)
58+
59+
julia> @views IVM.sin(a[1:2:end]) == b[1:2:end] # all IVM functions support 1d strided input
60+
true
5861
```
5962

6063
### Accuracy
@@ -247,6 +250,8 @@ Next steps for this package
247250
IntelVectorMath.jl uses [CpuId.jl](https://github.com/m-j-w/CpuId.jl) to detect if your processor supports the newer `avx2` instructions, and if not defaults to `libmkl_vml_avx`. If your system does not have AVX this package will currently not work for you.
248251
If the CPU feature detection does not work for you, please open an issue. -->
249252

250-
As a quick help to convert benchmark timings into operations-per-cycle, IntelVectorMath.jl
253+
1. As a quick help to convert benchmark timings into operations-per-cycle, IntelVectorMath.jl
251254
provides `vml_get_cpu_frequency()` which will return the *actual* current frequency of the
252255
CPU in GHz.
256+
257+
2. Now all IVM functions accept inputs that could be reshaped to an 1d [strided array](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays).

src/IntelVectorMath.jl

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,39 @@ for t in (Float32, Float64)
3535
def_unary_op(t, t, :cbrt, :cbrt!, :Cbrt)
3636
def_unary_op(t, t, :expm1, :expm1!, :Expm1)
3737
def_unary_op(t, t, :log1p, :log1p!, :Log1p)
38+
def_unary_op(t, t, :log2, :log2!, :Log2)
3839
def_unary_op(t, t, :abs, :abs!, :Abs)
3940
def_unary_op(t, t, :abs2, :abs2!, :Sqr)
4041
def_unary_op(t, t, :ceil, :ceil!, :Ceil)
4142
def_unary_op(t, t, :floor, :floor!, :Floor)
4243
def_unary_op(t, t, :round, :round!, :Round)
4344
def_unary_op(t, t, :trunc, :trunc!, :Trunc)
4445

46+
# Enabled only for Real. MKL guarantees higher accuracy, but at a
47+
# substantial performance cost.
48+
def_unary_op(t, t, :atan, :atan!, :Atan)
49+
def_unary_op(t, t, :cos, :cos!, :Cos)
50+
def_unary_op(t, t, :sin, :sin!, :Sin)
51+
def_unary_op(t, t, :tan, :tan!, :Tan)
52+
def_unary_op(t, t, :atanh, :atanh!, :Atanh)
53+
def_unary_op(t, t, :cosh, :cosh!, :Cosh)
54+
def_unary_op(t, t, :sinh, :sinh!, :Sinh)
55+
def_unary_op(t, t, :tanh, :tanh!, :Tanh)
56+
def_unary_op(t, t, :log10, :log10!, :Log10)
57+
58+
# Unary, real-only
59+
def_unary_op(t, t, :cospi, :cospi!, :Cospi)
60+
def_unary_op(t, t, :sinpi, :sinpi!, :Sinpi)
61+
def_unary_op(t, t, :tanpi, :tanpi!, :Tanpi)
62+
def_unary_op(t, t, :acospi, :acospi!, :Acospi)
63+
def_unary_op(t, t, :asinpi, :asinpi!, :Asinpi)
64+
def_unary_op(t, t, :atanpi, :atanpi!, :Atanpi)
65+
def_unary_op(t, t, :cosd, :cosd!, :Cosd)
66+
def_unary_op(t, t, :sind, :sind!, :Sind)
67+
def_unary_op(t, t, :tand, :tand!, :Tand)
68+
69+
def_one2two_op(t, t, :sincos, :sincos!, :SinCos)
70+
4571
# now in SpecialFunctions (make smart, maybe?)
4672
def_unary_op(t, t, :erf, :erf!, :Erf)
4773
def_unary_op(t, t, :erfc, :erfc!, :Erfc)
@@ -55,18 +81,6 @@ for t in (Float32, Float64)
5581
def_unary_op(t, t, :pow2o3, :pow2o3!, :Pow2o3)
5682
def_unary_op(t, t, :pow3o2, :pow3o2!, :Pow3o2)
5783

58-
# Enabled only for Real. MKL guarantees higher accuracy, but at a
59-
# substantial performance cost.
60-
def_unary_op(t, t, :atan, :atan!, :Atan)
61-
def_unary_op(t, t, :cos, :cos!, :Cos)
62-
def_unary_op(t, t, :sin, :sin!, :Sin)
63-
def_unary_op(t, t, :tan, :tan!, :Tan)
64-
def_unary_op(t, t, :atanh, :atanh!, :Atanh)
65-
def_unary_op(t, t, :cosh, :cosh!, :Cosh)
66-
def_unary_op(t, t, :sinh, :sinh!, :Sinh)
67-
def_unary_op(t, t, :tanh, :tanh!, :Tanh)
68-
def_unary_op(t, t, :log10, :log10!, :Log10)
69-
7084
# # .^ to scalar power
7185
# mklfn = Base.Meta.quot(Symbol("$(vml_prefix(t))Powx"))
7286
# @eval begin
@@ -87,6 +101,7 @@ for t in (Float32, Float64)
87101

88102
# # Binary, real-only
89103
def_binary_op(t, t, :atan, :atan!, :Atan2, false)
104+
def_binary_op(t, t, :atanpi, :atanpi!, :Atan2pi, false)
90105
def_binary_op(t, t, :hypot, :hypot!, :Hypot, false)
91106

92107
# Unary, complex-only

src/setup.jl

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -264,57 +264,131 @@ function vml_prefix(t::DataType)
264264
error("unknown type $t")
265265
end
266266

267+
if isdefined(Base, :_checkcontiguous)
268+
alldense(@nospecialize(x)) = Base._checkcontiguous(Bool, x)
269+
else
270+
alldense(x) = x isa DenseArray
271+
alldense(x::Base.ReshapedArray) = alldense(parent(x))
272+
alldense(x::Base.FastContiguousSubArray) = alldense(parent(x))
273+
alldense(x::Base.ReinterpretArray) = alldense(parent(x))
274+
end
275+
alldense(x, y, z...) = alldense(x) && alldense(y, z...)
276+
277+
if isdefined(Base, :merge_adjacent_dim)
278+
const merge_adjacent_dim = Base.merge_adjacent_dim
279+
else
280+
merge_adjacent_dim(::Dims{0}, ::Dims{0}) = 1, 1, 0
281+
merge_adjacent_dim(apsz::Dims{1}, apst::Dims{1}) = apsz[1], apst[1], 1
282+
function merge_adjacent_dim(apsz::Dims{N}, apst::Dims{N}, n::Int = 1) where {N}
283+
sz, st = apsz[n], apst[n]
284+
while n < N
285+
szₙ, stₙ = apsz[n+1], apst[n+1]
286+
if sz == 1
287+
sz, st = szₙ, stₙ
288+
elseif stₙ == st * sz || szₙ == 1
289+
sz *= szₙ
290+
else
291+
break
292+
end
293+
n += 1
294+
end
295+
return sz, st, n
296+
end
297+
end
298+
299+
getstrides(x...) = map(stride1, x)
300+
function stride1(x::AbstractArray)
301+
alldense(x) && return 1
302+
ndims(x) == 1 && return stride(x, 1)
303+
szs::Dims = size(x)
304+
sts::Dims = strides(x)
305+
_, st, n = merge_adjacent_dim(szs, sts)
306+
n === ndims(x) && return st
307+
throw(ArgumentError("only support vector like inputs"))
308+
end
309+
267310
function def_unary_op(tin, tout, jlname, jlname!, mklname;
268311
vmltype = tin)
269-
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$mklname"))
312+
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$(mklname)I"))
313+
mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$mklname"))
270314
exports = Symbol[]
271315
(@isdefined jlname) || push!(exports, jlname)
272316
(@isdefined jlname!) || push!(exports, jlname!)
273317
@eval begin
274-
function ($jlname!)(out::Array{$tout}, A::Array{$tin})
318+
function ($jlname!)(out::AbstractArray{$tout}, A::AbstractArray{$tin})
275319
size(out) == size(A) || throw(DimensionMismatch())
276-
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
320+
if alldense(out, A) || ((sts = getstrides(out, A)) == (1, 1))
321+
ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
322+
else
323+
stᵒ, stᴬ = sts
324+
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Int, Ptr{$tout}, Int), length(A), A, stᴬ, out, stᵒ)
325+
end
277326
vml_check_error()
278327
return out
279328
end
280329
$(if tin == tout
281330
quote
282-
function $(jlname!)(A::Array{$tin})
283-
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, A)
331+
function $(jlname!)(A::AbstractArray{$tin})
332+
if alldense(A) || ((sts = getstrides(A)) == (1,))
333+
ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, A)
334+
else
335+
(stᴬ,) = sts
336+
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Int, Ptr{$tout}, Int), length(A), A, stᴬ, A, stᴬ)
337+
end
284338
vml_check_error()
285339
return A
286340
end
287341
end
288342
end)
289-
function ($jlname)(A::Array{$tin})
290-
out = similar(A, $tout)
291-
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
292-
vml_check_error()
293-
return out
294-
end
343+
($jlname)(A::AbstractArray{$tin}) = $(jlname!)(similar(A, $tout), A)
295344
$(isempty(exports) ? nothing : Expr(:export, exports...))
296345
end
297346
end
298347

299348
function def_binary_op(tin, tout, jlname, jlname!, mklname, broadcast)
300-
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(tin))$mklname"))
349+
mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(tin))$mklname"))
350+
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(tin))$(mklname)I"))
301351
exports = Symbol[]
302352
(@isdefined jlname) || push!(exports, jlname)
303353
(@isdefined jlname!) || push!(exports, jlname!)
304354
@eval begin
305355
$(isempty(exports) ? nothing : Expr(:export, exports...))
306-
function ($jlname!)(out::Array{$tout}, A::Array{$tin}, B::Array{$tin})
307-
size(out) == size(A) == size(B) || throw(DimensionMismatch("Input arrays and output array need to have the same size"))
308-
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
356+
function ($jlname!)(out::AbstractArray{$tout}, A::AbstractArray{$tin}, B::AbstractArray{$tin})
357+
size(A) == size(B) || throw(DimensionMismatch("Input arrays need to have the same size"))
358+
size(out) == size(A) || throw(DimensionMismatch("Output array need to have the same size with input"))
359+
if alldense(out, A, B) || ((sts = getstrides(out, A, B)) == (1, 1, 1))
360+
ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
361+
else
362+
stᵒ, stᴬ, stᴮ = sts
363+
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Int, Ptr{$tin}, Int, Ptr{$tout}, Int), length(A), A, stᴬ, B, stᴮ, out, stᵒ)
364+
end
309365
vml_check_error()
310366
return out
311367
end
312-
function ($jlname)(A::Array{$tout}, B::Array{$tin})
313-
size(A) == size(B) || throw(DimensionMismatch("Input arrays need to have the same size"))
314-
out = similar(A)
315-
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
368+
($jlname)(A::AbstractArray{$tin}, B::AbstractArray{$tin}) = ($jlname!)(similar(A, $tout), A, B)
369+
end
370+
end
371+
372+
function def_one2two_op(tin, tout, jlname, jlname!, mklname)
373+
mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(tin))$mklname"))
374+
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(tin))$(mklname)I"))
375+
exports = Symbol[]
376+
(@isdefined jlname) || push!(exports, jlname)
377+
(@isdefined jlname!) || push!(exports, jlname!)
378+
@eval begin
379+
$(isempty(exports) ? nothing : Expr(:export, exports...))
380+
function ($jlname!)(out1::AbstractArray{$tout}, out2::AbstractArray{$tout}, A::AbstractArray{$tin})
381+
size(out1) == size(out2) || throw(DimensionMismatch("Output arrays need to have the same size"))
382+
size(A) == size(out2) || throw(DimensionMismatch("Output array need to have the same size with input"))
383+
if alldense(out1, out2, A) || ((sts = getstrides(out1, out2, A)) == (1, 1, 1))
384+
ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, out1, out2)
385+
else
386+
st¹, st², stᴬ = sts
387+
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Int, Ptr{$tin}, Int, Ptr{$tout}, Int), length(A), A, stᴬ, out1, st¹, out2, st²)
388+
end
316389
vml_check_error()
317-
return out
390+
return out1, out2
318391
end
392+
($jlname)(A::AbstractArray{$tin}) = ($jlname!)(similar(A, $tout), similar(A, $tout), A)
319393
end
320394
end

test/common.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
using SpecialFunctions
22
const base_unary_real = (
33
(Base, :acos, (-1, 1)),
4+
(Base, :acospi, (-1, 1)),
45
(Base, :asin, (-1, 1)),
6+
(Base, :asinpi, (-1, 1)),
57
(Base, :atan, (-50, 50)),
8+
(Base, :atanpi, (-50, 50)),
69
(Base, :cos, (-1000, 1000)),
10+
(Base, :cosd, (-10000, 10000)),
11+
(Base, :cospi, (-300, 300)),
712
(Base, :sin, (-1000, 1000)),
13+
(Base, :sind, (-10000, 10000)),
14+
(Base, :sinpi, (-300, 300)),
815
(Base, :tan, (-1000, 1000)),
16+
(Base, :tand, (-10000, 10000)),
17+
(Base, :tanpi, (-300, 300)),
918
(Base, :acosh, (1, 1000)),
1019
(Base, :asinh, (-1000, 1000)),
1120
(Base, :atanh, (-1, 1)),
@@ -17,6 +26,7 @@ const base_unary_real = (
1726
(Base, :exp, (-88.72284f0, 88.72284f0)),
1827
(Base, :expm1, (-88.72284f0, 88.72284f0)),
1928
(Base, :log, (0, 1000)),
29+
(Base, :log2, (0, 1000)),
2030
(Base, :log10, (0, 1000)),
2131
(Base, :log1p, (-1, 1000)),
2232
(Base, :abs, (-1000, 1000)),

test/real.jl

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,58 @@ const fns = [[x[1:2] for x in base_unary_real]; [x[1:2] for x in base_binary_rea
1919
for t in (Float32, Float64), i = 1:length(fns)
2020
inp = input[t][i]
2121
mod, fn = fns[i]
22-
base_fn = getproperty(mod, fn)
22+
if fn === :acospi || fn === :asinpi || fn === :atanpi
23+
fn′ = getproperty(mod, Symbol(string(fn)[1:end-2]))
24+
base_fn = x -> oftype(x, fn′(widen(x))/pi)
25+
elseif fn === :tanpi
26+
base_fn = x -> oftype(x, Base.tan(widen(x)*pi))
27+
else
28+
base_fn = getproperty(mod, fn)
29+
end
2330
vml_fn = getproperty(IntelVectorMath, fn)
2431
vml_fn! = getproperty(IntelVectorMath, Symbol(fn, "!"))
2532

2633
Test.@test parentmodule(vml_fn) == IntelVectorMath
2734

2835
# Test.test_approx_eq(output[t][i], fn(input[t][i]...), "Base $t $fn", "IntelVectorMath $t $fn")
2936
baseres = base_fn.(inp...)
30-
Test.@test vml_fn(inp...) base_fn.(inp...)
37+
Test.@test vml_fn(inp...) baseres
3138

3239
# cis changes type (float to complex, does not have mutating function)
3340
if length(inp) == 1
3441
if fn != :cis
35-
vml_fn!(inp[1])
36-
Test.@test inp[1] baseres
42+
temp = similar(inp[1], 2NVALS)
43+
inp1′ = @views copyto!(temp[1:2:end], inp[1])
44+
inp1″ = @views copyto!(temp[end:-2:1], inp[1])
45+
for x in (inp[1], inp1′, inp1″)
46+
vml_fn!(x)
47+
Test.@test x baseres
48+
end
3749
end
3850
elseif length(inp) == 2
3951
out = similar(inp[1])
40-
vml_fn!(out, inp...)
41-
Test.@test out baseres
52+
temp = similar(inp[1], 2NVALS)
53+
x′ = @views copyto!(temp[1:2:end], inp[1])
54+
y′ = @views copyto!(temp[end:-2:1], inp[2])
55+
for (x, y) in (inp, (x′, y′))
56+
vml_fn!(out, x, y)
57+
Test.@test out baseres
58+
end
4259
end
4360

4461
end
4562

4663
end
4764

65+
@testset "sincos" begin
66+
for t in (Float32, Float64)
67+
a = randindomain(t, NVALS, (-1000, 1000))
68+
s, c = IVM.sincos(a)
69+
@test s IVM.sin(a)
70+
@test c IVM.cos(a)
71+
end
72+
end
73+
4874
@testset "Error Handling and Settings" begin
4975

5076
# Verify that we still throw DomainErrors

0 commit comments

Comments
 (0)