Skip to content

Call xxxI instead of xxx to support strided layout. #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions src/IntelVectorMath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ function __init__()
if Sys.isapple() && haskey(Base.loaded_modules, Base.PkgId(compilersupportlibaries_jll_uuid, "CompilerSupportLibraries_jll"))
@warn "It appears CompilerSupportLibraries_jll was loaded prior to this package, which currently on mac may lead to wrong results in some cases. For further details see github.com/JuliaMath/IntelVectorMath.jl"
end
set_num_threads(Threads.nthreads())
end

for t in (Float32, Float64, ComplexF32, ComplexF64)
Expand All @@ -35,13 +36,37 @@ for t in (Float32, Float64)
def_unary_op(t, t, :cbrt, :cbrt!, :Cbrt)
def_unary_op(t, t, :expm1, :expm1!, :Expm1)
def_unary_op(t, t, :log1p, :log1p!, :Log1p)
def_unary_op(t, t, :log2, :log2!, :Log2)
def_unary_op(t, t, :abs, :abs!, :Abs)
def_unary_op(t, t, :abs2, :abs2!, :Sqr)
def_unary_op(t, t, :ceil, :ceil!, :Ceil)
def_unary_op(t, t, :floor, :floor!, :Floor)
def_unary_op(t, t, :round, :round!, :Round)
def_unary_op(t, t, :trunc, :trunc!, :Trunc)

# Enabled only for Real. MKL guarantees higher accuracy, but at a
# substantial performance cost.
def_unary_op(t, t, :atan, :atan!, :Atan)
def_unary_op(t, t, :cos, :cos!, :Cos)
def_unary_op(t, t, :sin, :sin!, :Sin)
def_unary_op(t, t, :tan, :tan!, :Tan)
def_unary_op(t, t, :atanh, :atanh!, :Atanh)
def_unary_op(t, t, :cosh, :cosh!, :Cosh)
def_unary_op(t, t, :sinh, :sinh!, :Sinh)
def_unary_op(t, t, :tanh, :tanh!, :Tanh)
def_unary_op(t, t, :log10, :log10!, :Log10)

# Unary, real-only
def_unary_op(t, t, :cospi, :cospi!, :Cosp)
def_unary_op(t, t, :sinpi, :sinpi!, :Sinp)
def_unary_op(t, t, :tanpi, :tanpi!, :Tanp)
def_unary_op(t, t, :acospi, :acospi!, :Acosp)
def_unary_op(t, t, :asinpi, :asinpi!, :Asinp)
def_unary_op(t, t, :atanpi, :atanpi!, :Atanp)
def_unary_op(t, t, :cosd, :cosd!, :Cosd)
def_unary_op(t, t, :sind, :sind!, :Sind)
def_unary_op(t, t, :tand, :tand!, :Tand)

# now in SpecialFunctions (make smart, maybe?)
def_unary_op(t, t, :erf, :erf!, :Erf)
def_unary_op(t, t, :erfc, :erfc!, :Erfc)
Expand All @@ -55,18 +80,6 @@ for t in (Float32, Float64)
def_unary_op(t, t, :pow2o3, :pow2o3!, :Pow2o3)
def_unary_op(t, t, :pow3o2, :pow3o2!, :Pow3o2)

# Enabled only for Real. MKL guarantees higher accuracy, but at a
# substantial performance cost.
def_unary_op(t, t, :atan, :atan!, :Atan)
def_unary_op(t, t, :cos, :cos!, :Cos)
def_unary_op(t, t, :sin, :sin!, :Sin)
def_unary_op(t, t, :tan, :tan!, :Tan)
def_unary_op(t, t, :atanh, :atanh!, :Atanh)
def_unary_op(t, t, :cosh, :cosh!, :Cosh)
def_unary_op(t, t, :sinh, :sinh!, :Sinh)
def_unary_op(t, t, :tanh, :tanh!, :Tanh)
def_unary_op(t, t, :log10, :log10!, :Log10)

# # .^ to scalar power
# mklfn = Base.Meta.quot(Symbol("$(vml_prefix(t))Powx"))
# @eval begin
Expand All @@ -87,6 +100,7 @@ for t in (Float32, Float64)

# Binary, real-only
def_binary_op(t, t, :atan, :atan!, :Atan2, false)
def_binary_op(t, t, :atanpi, :atanpi!, :Atan2pi, false)
def_binary_op(t, t, :hypot, :hypot!, :Hypot, false)

# Unary, complex-only
Expand All @@ -109,4 +123,12 @@ end

export VML_LA, VML_HA, VML_EP, vml_set_accuracy, vml_get_accuracy

"""
    get_num_threads() -> Int

Return the value reported by MKL's `MKL_Domain_Get_Max_Threads` for domain id `3`
(the same domain id that `set_num_threads` configures), converted to `Int`.
"""
function get_num_threads()::Int
    # Domain id handed to MKL; Intel's headers name value 3 `MKL_DOMAIN_VML`.
    domain = 3
    max_threads = ccall(
        (:MKL_Domain_Get_Max_Threads, MKL_jll.libmkl_rt), Cint, (Cint,), domain
    )
    return max_threads
end
"""
    set_num_threads(n::Integer)

Ask MKL (via `MKL_Domain_Set_Num_Threads`, domain id `3`) to use `n` threads, capped at
`Threads.nthreads()`. Returns `nothing`.

# Throws
- `ErrorException` when the MKL call returns anything other than `1`.
"""
function set_num_threads(n::Integer)
    # Never request more MKL threads than the Julia session itself has.
    requested = min(n, Threads.nthreads())
    flag = ccall(
        (:MKL_Domain_Set_Num_Threads, MKL_jll.libmkl_rt),
        Cint,
        (Cint, Cint),
        requested,
        3,
    )
    # Raise a real exception object: throwing a bare String gives callers nothing sensible
    # to catch by type and renders poorly in stack traces.
    flag == 1 || error("Vml threads setting failed with $flag")
    return nothing
end
end
65 changes: 43 additions & 22 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const VML_EP = VMLAccuracy(0x00000003)

# Print the symbolic name of an accuracy mode: "VML_LA" or "VML_HA" when `m` equals the
# corresponding constant, and "VML_EP" for every other value.
function Base.show(io::IO, m::VMLAccuracy)
    label = if m == VML_LA
        "VML_LA"
    elseif m == VML_HA
        "VML_HA"
    else
        "VML_EP"
    end
    return print(io, label)
end

# Return the `Cuint` mode word reported by MKL's `vmlGetMode`.
function vml_get_mode()
    return ccall((:vmlGetMode, MKL_jll.libmkl_rt), Cuint, ())
end
# Pass `mode` to MKL's `vmlSetMode`; the value returned by MKL is discarded and `nothing`
# is returned instead.
function vml_set_mode(mode::Integer)
    ccall((:vmlSetMode, MKL_jll.libmkl_rt), Cuint, (UInt,), mode)
    return nothing
end

Expand Down Expand Up @@ -45,58 +45,79 @@ function vml_prefix(t::DataType)
end
error("unknown type $t")
end
"""
    alldense(xs...) -> Bool

Return `true` only when every argument is a `DenseArray`, or is a `Base.ReshapedArray`,
`Base.FastContiguousSubArray` or `Base.ReinterpretArray` whose (recursively unwrapped)
parent is one. Evaluation stops at the first argument that fails the check.
"""
alldense(x, y, z...) = alldense(x) ? alldense(y, z...) : false

# Base cases, selected by dispatch: dense arrays qualify, anything else does not.
alldense(::DenseArray) = true
alldense(::Any) = false

# Wrappers take their answer from the array they wrap.
alldense(x::Base.ReshapedArray) = alldense(parent(x))
alldense(x::Base.FastContiguousSubArray) = alldense(parent(x))
alldense(x::Base.ReinterpretArray) = alldense(parent(x))

# Return a tuple holding `stride1(a)` for each argument `a`, in argument order.
function getstrides(x...)
    return map(stride1, x)
end
"""
    stride1(x::AbstractArray)

Return the single element stride used to walk `x` linearly:

- `1` when `alldense(x)` holds;
- `stride(x, 1)` when `x` is one-dimensional;
- otherwise the first entry of `strides(x)`, after checking that the full stride tuple
  equals `Base.size_to_strides(first_stride, size(x)...)`.

# Throws
- `ArgumentError("Invalid memory layout.")` when that check fails.
"""
function stride1(x::AbstractArray)
    if alldense(x)
        return 1
    elseif ndims(x) == 1
        return stride(x, 1)
    end
    layout = strides(x)
    first_stride = layout[1]
    # Strides that a column-major array of this size would have with this leading stride.
    expected = Base.size_to_strides(first_stride, size(x)...)
    if layout != expected
        throw(ArgumentError("Invalid memory layout."))
    end
    return first_stride
end

# Generate, via `@eval`, the array methods for one unary operation backed by a routine in
# `MKL_jll.libmkl_rt`:
#   - `jlname!(out, A)`  writes the result for `A` into `out`;
#   - `jlname!(A)`       in-place form, generated only when `tin == tout`;
#   - `jlname(A)`        allocating form.
# `tin`/`tout` are the input/output element types, `mklname` is the routine name without its
# type prefix, and `vmltype` (default `tin`) is the type handed to `vml_prefix`.
#
# NOTE(review): this span comes from a rendered diff, so superseded lines appear next to
# their replacements (two `mklfn =` assignments, paired `function` headers, leftover
# `Array`-only bodies). It is not valid Julia as shown; the comments below mark which
# lines look superseded — confirm against the real file.
function def_unary_op(tin, tout, jlname, jlname!, mklname;
vmltype = tin)
# Superseded-looking assignment: immediately overwritten by the next line.
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$mklname"))
# Routine name with an "I" suffix; called below with explicit strides.
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$(mklname)I"))
# Routine name without the suffix; called below when `alldense` holds.
mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$mklname"))
exports = Symbol[]
# NOTE(review): `@isdefined jlname` tests the local argument variable `jlname`, which is
# always bound inside this function, so these pushes look unreachable — confirm intent.
(@isdefined jlname) || push!(exports, jlname)
(@isdefined jlname!) || push!(exports, jlname!)
@eval begin
# Superseded-looking header (`Array`-only signature).
function ($jlname!)(out::Array{$tout}, A::Array{$tin})
function ($jlname!)(out::AbstractArray{$tout}, A::AbstractArray{$tin})
size(out) == size(A) || throw(DimensionMismatch())
# Superseded-looking unconditional call.
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
if alldense(out, A)
# Dense path: arguments are (length, input pointer, output pointer).
ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
else
# Strided path: each pointer is followed by its stride from `getstrides`.
stᵒ, stᴬ = getstrides(out, A)
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Int, Ptr{$tout}, Int), length(A), A, stᴬ, out, stᵒ)
end
# `vml_check_error` is defined outside this view.
vml_check_error()
return out
end
# The in-place method is only emitted when input and output element types coincide.
$(if tin == tout
quote
# Superseded-looking header and call (`Array`-only signature).
function $(jlname!)(A::Array{$tin})
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, A)
function $(jlname!)(A::AbstractArray{$tin})
if alldense(A)
ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, A)
else
# `A` is both source and destination, so the same stride is passed twice.
(stᴬ,) = getstrides(A)
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Int, Ptr{$tout}, Int), length(A), A, stᴬ, A, stᴬ)
end
vml_check_error()
return A
end
end
end)
# Superseded-looking allocating method (`Array`-only); the one-line definition after it
# delegates to `jlname!` instead.
function ($jlname)(A::Array{$tin})
out = similar(A, $tout)
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tout}), length(A), A, out)
vml_check_error()
return out
end
# Allocating form: build an output with element type `tout` and reuse `jlname!`.
($jlname)(A::AbstractArray{$tin}) = $(jlname!)(similar(A, $tout), A)
# Emit an `export` statement only when `exports` is non-empty.
$(isempty(exports) ? nothing : Expr(:export, exports...))
end
end

# Generate, via `@eval`, the array methods for one binary operation backed by a routine in
# `MKL_jll.libmkl_rt`:
#   - `jlname!(out, A, B)`  writes the result for `A` and `B` into `out`;
#   - `jlname(A, B)`        allocating form.
# `tin`/`tout` are the input/output element types and `mklname` is the routine name without
# its type prefix. The `broadcast` argument is not referenced in the lines visible here.
#
# NOTE(review): this span comes from a rendered diff, so superseded lines appear next to
# their replacements. It is not valid Julia as shown; the comments below mark which lines
# look superseded — confirm against the real file.
function def_binary_op(tin, tout, jlname, jlname!, mklname, broadcast)
# Superseded-looking assignment: overwritten two lines below.
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(tin))$mklname"))
# Routine name without a suffix; called below when `alldense` holds.
mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(tin))$mklname"))
# Routine name with an "I" suffix; called below with explicit strides.
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(tin))$(mklname)I"))
exports = Symbol[]
# NOTE(review): `@isdefined jlname` tests the local argument variable `jlname`, which is
# always bound inside this function, so these pushes look unreachable — confirm intent.
(@isdefined jlname) || push!(exports, jlname)
(@isdefined jlname!) || push!(exports, jlname!)
@eval begin
# Emit an `export` statement only when `exports` is non-empty.
$(isempty(exports) ? nothing : Expr(:export, exports...))
# Superseded-looking `Array`-only method.
function ($jlname!)(out::Array{$tout}, A::Array{$tin}, B::Array{$tin})
size(out) == size(A) == size(B) || throw(DimensionMismatch("Input arrays and output array need to have the same size"))
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
vml_check_error()
return out
end
# Superseded-looking header of the old allocating method.
function ($jlname)(A::Array{$tout}, B::Array{$tin})
function ($jlname!)(out::AbstractArray{$tout}, A::AbstractArray{$tin}, B::AbstractArray{$tin})
size(A) == size(B) || throw(DimensionMismatch("Input arrays need to have the same size"))
# Superseded-looking body of the old allocating method.
out = similar(A)
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
size(out) == size(A) || throw(DimensionMismatch("Output array need to have the same size with input"))
if alldense(out, A, B)
# Dense path: arguments are (length, first input, second input, output).
ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}), length(A), A, B, out)
else
# Strided path: each pointer is followed by its stride from `getstrides`.
stᵒ, stᴬ, stᴮ = getstrides(out, A, B)
ccall(($mklfn, MKL_jll.libmkl_rt), Nothing, (Int, Ptr{$tin}, Int, Ptr{$tin}, Int, Ptr{$tout}, Int), length(A), A, stᴬ, B, stᴮ, out, stᵒ)
end
# `vml_check_error` is defined outside this view.
vml_check_error()
return out
end
# Allocating form: build an output with element type `tout` and reuse `jlname!`.
($jlname)(A::AbstractArray{$tin}, B::AbstractArray{$tin}) = ($jlname!)(similar(A, $tout), A, B)
end
end