Skip to content

some refactoring of the overload macro #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@ version = "0.3"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
BinaryProvider = "0.5.8"
CpuId = "0.2"
SpecialFunctions = "0.8, 0.9, 0.10"
julia = "0.7, 1.0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[targets]
test = ["Test"]
test = ["Test", "SpecialFunctions"]
59 changes: 25 additions & 34 deletions src/IntelVectorMath.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
__precompile__()

module IntelVectorMath

export IVM
const IVM = IntelVectorMath

# import Base: .^, ./
using SpecialFunctions
# using Libdl
include("../deps/deps.jl")

Expand Down Expand Up @@ -105,46 +102,40 @@ for t in (Float32, Float64)
end

"""
@overload exp log sin

This macro adds a method to each function in `Base` (or perhaps in `SpecialFunctions`),
so that when acting on an array (or two arrays) it calls the `IntelVectorMath` function of the same name.
@vml_overload Base.exp Base.log SpecialFunctions.erfc

The existing action on scalars is unaffected. However, `exp(M::Matrix)` will now mean
element-wise `IntelVectorMath.exp(M) == exp.(M)`, rather than matrix exponentiation.
This macro adds a method to each given function in `Base` or `SpecialFunctions`,
so that when acting on a `Vector` (or two `Vector`s) it calls the `IntelVectorMath` function of the same name.
"""
macro overload(funs...)
macro vml_overload(funs...)
out = quote end
say = []
for f in funs
if f in _UNARY
if isdefined(Base, f)
push!(out.args, :( Base.$f(A::Array) = IntelVectorMath.$f(A) ))
push!(say, "Base.$f(A)")
elseif isdefined(SpecialFunctions, f)
push!(out.args, :( IntelVectorMath.SpecialFunctions.$f(A::Array) = IntelVectorMath.$f(A) ))
push!(say, "SpecialFunctions.$f(A)")
else
@error "function IntelVectorMath.$f is not defined in Base or SpecialFunctions, so there is nothing to overload"
end
if f.head !== :(.) || !(length(f.args) == 2) || !(f.args[1] isa Symbol && f.args[2] isa QuoteNode)
error("expected a Module.function type of expression, got $f")
end
mod, f = f.args[1], f.args[2].value
if !(mod in (:Base, :SpecialFunctions))
error("expected module to be either Base or SpecialFunctions, got $mod")
end
if f in keys(_UNARY)
input_types = _UNARY[f]
expr = :($(esc(mod)).$f(A::Vector{T}) where {T <: Union{$(input_types...)}} =
IntelVectorMath.$f(A))
push!(out.args, expr)
end
if f in _BINARY
if isdefined(Base, f)
push!(out.args, :( Base.$f(A::Array, B::Array) = IntelVectorMath.$f(A, B) ))
push!(say, "Base.$f(A, B)")
else
@error "function IntelVectorMath.$f is not defined in Base, so there is nothing to overload"
end
if f in keys(_BINARY)
input_types = _BINARY[f]
expr = :($(esc(mod)).$f(A::Vector{T}, B::Vector{T}) where {T <: Union{$(input_types...)}} =
IntelVectorMath.$f(A, B))
push!(out.args, expr)
end
if !(f in _UNARY) && !(f in _BINARY)
error("there is no function $f defined by IntelVectorMath.jl")
if !(f in keys(_UNARY)) && !(f in keys(_BINARY))
error("there is no function $f defined in IntelVectorMath.jl")
end
end
str = string("Overloaded these functions: \n ", join(say, " \n "))
push!(out.args, str)
esc(out)
return out
end

export VML_LA, VML_HA, VML_EP, vml_set_accuracy, vml_get_accuracy, @overload
export VML_LA, VML_HA, VML_EP, vml_set_accuracy, vml_get_accuracy, @vml_overload

end
10 changes: 5 additions & 5 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ const VML_LA = VMLAccuracy(0x00000001)
const VML_HA = VMLAccuracy(0x00000002)
const VML_EP = VMLAccuracy(0x00000003)

const _UNARY = [] # for @overload to check
const _BINARY = []
const _UNARY = Dict{Symbol, Vector{DataType}}() # for @vml_overload to check
const _BINARY = Dict{Symbol, Vector{DataType}}()

Base.show(io::IO, m::VMLAccuracy) = print(io, m == VML_LA ? "VML_LA" :
m == VML_HA ? "VML_HA" : "VML_EP")
Expand Down Expand Up @@ -59,13 +59,13 @@ function vml_prefix(t::DataType)
error("unknown type $t")
end

function def_unary_op(tin, tout, jlname, jlname!, mklname;
function def_unary_op(tin, tout, jlname, jlname!, mklname;
vmltype = tin)
mklfn = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$mklname"))
exports = Symbol[]
(@isdefined jlname) || push!(exports, jlname)
(@isdefined jlname!) || push!(exports, jlname!)
push!(_UNARY, jlname)
push!(get!(_UNARY, jlname, DataType[]), tin)
@eval begin
function ($jlname!)(out::Array{$tout,N}, A::Array{$tin,N}) where {N}
size(out) == size(A) || throw(DimensionMismatch())
Expand Down Expand Up @@ -97,7 +97,7 @@ function def_binary_op(tin, tout, jlname, jlname!, mklname, broadcast)
exports = Symbol[]
(@isdefined jlname) || push!(exports, jlname)
(@isdefined jlname!) || push!(exports, jlname!)
push!(_BINARY, jlname)
push!(get!(_BINARY, jlname, DataType[]), tin)
@eval begin
$(isempty(exports) ? nothing : Expr(:export, exports...))
function ($jlname!)(out::Array{$tout,N}, A::Array{$tin,N}, B::Array{$tin,N}) where {N}
Expand Down
16 changes: 9 additions & 7 deletions test/real.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fns = [[x[1:2] for x in base_unary_real]; [x[1:2] for x in base_binary_real]]
@testset "Definitions and Comparison with Base for Reals" begin

for t in (Float32, Float64), i = 1:length(fns)
base_fn = eval(:($(fns[i][1]).$(fns[i][2])))
base_fn = eval(:($(fns[i][1]).$(fns[i][2])))
vml_fn = eval(:(IntelVectorMath.$(fns[i][2])))
vml_fn! = eval(:(IntelVectorMath.$(Symbol(fns[i][2], !))))

Expand All @@ -28,10 +28,10 @@ fns = [[x[1:2] for x in base_unary_real]; [x[1:2] for x in base_binary_real]]
Test.@test vml_fn(input[t][i]...) ≈ baseres

# cis changes type (float to complex, does not have mutating function)


if length(input[t][i]) == 1
if fns[i][2] != :cis
if fns[i][2] != :cis
vml_fn!(input[t][i]...)
Test.@test input[t][i][1] ≈ baseres
end
Expand Down Expand Up @@ -60,15 +60,17 @@ end

end

@testset "@overload macro" begin

@testset "@vml_overload macro" begin
@test IntelVectorMath.exp([1.0]) ≈ exp.([1.0])
@test_throws MethodError Base.exp([1.0])
@test (@overload log exp) isa String
@vml_overload Base.log Base.exp
@test Base.exp([1.0]) ≈ exp.([1.0])

@test_throws MethodError Base.atan([1.0], [2.0])
@test (@overload atan) isa String
@vml_overload Base.atan
@test Base.atan([1.0], [2.0]) ≈ atan.([1.0], [2.0])

@test_throws MethodError SpecialFunctions.erfc([1.0, 2.0])
@vml_overload SpecialFunctions.erfc
@test SpecialFunctions.erfc([1.0, 2.0]) ≈ SpecialFunctions.erfc.([1.0, 2.0])
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using IntelVectorMath
using SpecialFunctions

include("common.jl")
include("real.jl")
Expand Down