Skip to content

Use extensions for weakdeps #464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <[email protected]>"]
version = "0.12.148"
version = "0.12.149"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"]
SpecialFunctionsExt = "SpecialFunctions"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
79 changes: 68 additions & 11 deletions src/simdfunctionals/vmap_grad_rrule.jl → ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,60 @@
module ForwardDiffExt
import ForwardDiff, ChainRulesCore
using SIMDDualNumbers, LoopVectorization
using LoopVectorization:
AbstractSIMD,
AbstractStridedPointer,
relu,
vmap,
VectorizationBase,
vmapt,
vmapnt,
vmapntt,
MM,
StaticInt,
vadd_nw,
vsub_nsw,
vload,
mask,
vfnmadd_fast,
mul_fast
using VectorizationBase: zero_offsets

import .ChainRulesCore
@generated function init_dual(v::Tuple{Vararg{AbstractSIMD,A}}) where {A}
res = Expr(:tuple)
q = Expr(:block, Expr(:meta, :inline))
for a ∈ 1:A
v_a = Symbol(:v_, a)
push!(q.args, Expr(:(=), v_a, Expr(:ref, :v, a)))
partials = Expr(:tuple)
for i ∈ 1:A
push!(partials.args, Expr(:call, i == a ? :one : :zero, v_a))
end
push!(res.args, :(ForwardDiff.Dual($v_a, ForwardDiff.Partials($partials))))
end
push!(q.args, res)
q
end
@generated function dual_store!(
∂p::Tuple{Vararg{AbstractStridedPointer,A}},
p::AbstractStridedPointer,
∂v,
im::Vararg{Any,N}
) where {A,N}
quote
$(Expr(:meta, :inline))
v = ∂v.value
∂ = ∂v.partials
Base.Cartesian.@nextract $N im im
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! p v im # store
Base.Cartesian.@nexprs $A a -> begin # for each of `A` partials
∂p_a = ∂p[a]
∂_a = ∂[a]
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! ∂p_a ∂_a im # store
end
nothing
end
end

if isdefined(ChainRulesCore, :ZeroTangent)
const ChainRulesZero = ChainRulesCore.ZeroTangent
Expand Down Expand Up @@ -38,32 +93,33 @@ function ∂vmap_singlethread!(
args::Vararg{DenseArray{<:Base.HWReal},A}
) where {F,T<:Base.HWReal,A}
N = length(y)
ptry = VectorizationBase.zero_offsets(stridedpointer(y))
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
ptr∂y = VectorizationBase.zero_offsets.(stridedpointer.(∂y))

ptry = zero_offsets(stridedpointer(y))
ptrargs = map(zero_offsets, map(stridedpointer, args))
ptr∂y = map(zero_offsets, map(stridedpointer, ∂y))
i = 0
V = VectorizationBase.pick_vector_width(T)
W = Int(V)
st = VectorizationBase.static_sizeof(T)
zero_index = MM{W}(StaticInt(0), st)
while i < vsub_nsw(N, ((W << 2) - 1))
index = VectorizationBase.Unroll{1,W,4,1,W,zero(UInt)}((i,))
v = f(init_dual(vload.(ptrargs, index))...)
v = f(init_dual(map(Base.Fix2(vload, index), ptrargs))...)
dual_store!(ptr∂y, ptry, v, index)
i = vadd_nw(i, 4W)
end
while i < vsub_nsw(N, (W - 1))
vᵣ = f(init_dual(vload.(ptrargs, ((MM{W}(i),),)))...)
loader = Base.Fix2(vload, (MM{W}(i),))
vᵣ = f(init_dual(map(loader, ptrargs))...)
dual_store!(ptr∂y, ptry, vᵣ, (MM{W}(i),))
i = vadd_nw(i, W)
end
if i < N
m = mask(T, N & (W - 1))
mloader = let i = i, m = m
p -> vload(p, (MM{W}(i),), m)
end
dual_store!(
ptr∂y,
ptry,
f(init_dual(vload.(ptrargs, ((MM{W}(i),),), m))...),
f(init_dual(map(mloader, ptrargs))...),
(MM{W}(i),),
m
)
Expand Down Expand Up @@ -109,6 +165,7 @@ for f in (:vmapt, :vmapnt, :vmapntt)
f::F,
args::Vararg{Any,K}
) where {F,K}
ChainRulesCore.rrule(typeof(vmap), f, args...)
ChainRulesCore.rrule(typeof($vmap), f, args...)
end
end
end
6 changes: 6 additions & 0 deletions ext/SpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module SpecialFunctionsExt
using SpecialFunctions
using LoopVectorization: VectorizationBase
using LoopVectorization: AbstractSIMD
@inline SpecialFunctions.erf(x::AbstractSIMD) = VectorizationBase.verf(float(x))
end
8 changes: 4 additions & 4 deletions src/LoopVectorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ include("precompile.jl")

# import ChainRulesCore, ForwardDiff
# include("vmap_grad.jl")
using ChainRulesCore, ForwardDiff, SpecialFunctions
include("simdfunctionals/vmap_grad_rrule.jl")
include("simdfunctionals/vmap_grad_forwarddiff.jl")
@inline SpecialFunctions.erf(x::AbstractSIMD) = VectorizationBase.verf(float(x))
if !isdefined(Base, :get_extension)
include("../ext/ForwardDiffExt.jl")
include("../ext/SpecialFunctionsExt.jl")
end

end # module
38 changes: 0 additions & 38 deletions src/simdfunctionals/vmap_grad_forwarddiff.jl

This file was deleted.