Skip to content

Commit 57c7ffa

Browse files
authored
Extensions (#467)
* Use extensions for weakdeps * LSP added "using Base: get_extension", which is probably a bad idea * Fix extension syntax * Fixes * module * Fix zygote tests * Copy SIMDDualNumbers contents into extension, dropping it
1 parent 807675d commit 57c7ffa

File tree

3 files changed

+189
-6
lines changed

3 files changed

+189
-6
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.149"
4+
version = "0.12.150"
55

66
[weakdeps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -28,7 +28,6 @@ LayoutPointers = "10f19ff3-798f-405d-979b-55457f8fc047"
2828
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2929
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3030
PolyesterWeave = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad"
31-
SIMDDualNumbers = "3cdde19b-5bb0-4aaf-8931-af3e248e098b"
3231
SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
3332
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
3433
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
@@ -53,7 +52,6 @@ IfElse = "0.1"
5352
LayoutPointers = "0.1.11"
5453
OffsetArrays = "1.4.1"
5554
PolyesterWeave = "0.1.10, 0.2"
56-
SIMDDualNumbers = "0.1"
5755
SIMDTypes = "0.1"
5856
SLEEFPirates = "0.6.23"
5957
SnoopPrecompile = "1"

ext/ForwardDiffExt.jl

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
module ForwardDiffExt
22
import ForwardDiff, ChainRulesCore
3-
using SIMDDualNumbers, LoopVectorization
3+
using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff
4+
5+
import IfElse: ifelse
6+
using VectorizationBase: AbstractSIMD, AbstractMask, zero_offsets
7+
48
using LoopVectorization:
59
AbstractSIMD,
610
AbstractStridedPointer,
@@ -18,7 +22,188 @@ using LoopVectorization:
1822
mask,
1923
vfnmadd_fast,
2024
mul_fast
21-
using VectorizationBase: zero_offsets
25+
26+
@generated function Base.abs(
27+
x::ForwardDiff.Dual{TAG,S,N}
28+
) where {TAG,S<:AbstractSIMD,N}
29+
quote
30+
$(Expr(:meta, :inline))
31+
val = x.value
32+
p = x.partials
33+
cmp = val < zero($S)
34+
absx = $ifelse(cmp, -val, val)
35+
Base.Cartesian.@nexprs $N n -> p_n = p[n]
36+
ForwardDiff.Dual{$TAG}(
37+
absx,
38+
ForwardDiff.Partials(
39+
Base.Cartesian.@ntuple $N n -> $ifelse(cmp, -p_n, p_n)
40+
)
41+
)
42+
end
43+
end
44+
@inline function Base.max(
45+
x::ForwardDiff.Dual{TAG,<:AbstractSIMD,N},
46+
y::ForwardDiff.Dual{TAG,<:AbstractSIMD,N}
47+
) where {TAG,N}
48+
vx = ForwardDiff.value(x)
49+
vy = ForwardDiff.value(y)
50+
xgy = vx > vy
51+
z = ifelse(xgy, vx, vy)
52+
p = VectorizationBase.fmap(
53+
ifelse,
54+
xgy,
55+
ForwardDiff.partials(x).values,
56+
ForwardDiff.partials(y).values
57+
)
58+
ForwardDiff.Dual{TAG}(z, ForwardDiff.Partials(p))
59+
end
60+
61+
@inline Base.max(
62+
x::T,
63+
y::Real
64+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = max(x, T(y))
65+
@inline Base.max(
66+
y::Real,
67+
x::T
68+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = max(x, T(y))
69+
@inline Base.max(
70+
x::T,
71+
y::Int
72+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = max(x, T(y))
73+
@inline Base.max(
74+
y::Int,
75+
x::T
76+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = max(x, T(y))
77+
78+
@inline function Base.min(
79+
x::ForwardDiff.Dual{TAG,<:AbstractSIMD,N},
80+
y::ForwardDiff.Dual{TAG,<:AbstractSIMD,N}
81+
) where {TAG,N}
82+
vx = ForwardDiff.value(x)
83+
vy = ForwardDiff.value(y)
84+
xgy = vx < vy
85+
z = ifelse(xgy, vx, vy)
86+
p = VectorizationBase.fmap(
87+
ifelse,
88+
xgy,
89+
ForwardDiff.partials(x).values,
90+
ForwardDiff.partials(y).values
91+
)
92+
ForwardDiff.Dual{TAG}(z, ForwardDiff.Partials(p))
93+
end
94+
@inline Base.min(
95+
x::T,
96+
y::Real
97+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = min(x, T(y))
98+
@inline Base.min(
99+
y::Real,
100+
x::T
101+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = min(x, T(y))
102+
@inline Base.min(
103+
x::T,
104+
y::Int
105+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = min(x, T(y))
106+
@inline Base.min(
107+
y::Int,
108+
x::T
109+
) where {N,T<:ForwardDiff.Dual{<:Any,<:AbstractSIMD,N}} = min(x, T(y))
110+
111+
@generated function SLEEFPirates.tanh_fast(
112+
x::ForwardDiff.Dual{T,S,N}
113+
) where {T,S,N}
114+
quote
115+
$(Expr(:meta, :inline))
116+
t = tanh_fast(x.value)
117+
∂t = $(VectorizationBase.vfnmadd_fast)(t, t, one(S))
118+
p = x.partials
119+
ForwardDiff.Dual{T}(
120+
t,
121+
ForwardDiff.Partials(
122+
Base.Cartesian.@ntuple $N n -> $(Base.FastMath.mul_fast)(∂t, p[n])
123+
)
124+
)
125+
end
126+
end
127+
@generated function SLEEFPirates.sigmoid_fast(
128+
x::ForwardDiff.Dual{T,S,N}
129+
) where {T,S,N}
130+
quote
131+
$(Expr(:meta, :inline))
132+
s = sigmoid_fast(x.value)
133+
∂s = $(VectorizationBase.vfnmadd_fast)(s, s, s)
134+
p = x.partials
135+
ForwardDiff.Dual{T}(
136+
s,
137+
ForwardDiff.Partials(
138+
Base.Cartesian.@ntuple $N n -> $(Base.FastMath.mul_fast)(∂s, p[n])
139+
)
140+
)
141+
end
142+
end
143+
@generated function VectorizationBase.relu(
144+
x::ForwardDiff.Dual{T,S,N}
145+
) where {T,S,N}
146+
quote
147+
$(Expr(:meta, :inline))
148+
v = x.value
149+
z = zero(v)
150+
cmp = v < z
151+
r = ifelse(cmp, z, v)
152+
p = x.partials
153+
ForwardDiff.Dual{T}(
154+
r,
155+
ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, z, p[n]))
156+
)
157+
end
158+
end
159+
160+
@generated function ifelse(
161+
m::AbstractMask,
162+
x::ForwardDiff.Dual{TAG,V,P},
163+
y::ForwardDiff.Dual{TAG,V,P}
164+
) where {TAG,V,P}
165+
quote
166+
$(Expr(:meta, :inline))
167+
z = $ifelse(m, ForwardDiff.value(x), ForwardDiff.value(y))
168+
px = ForwardDiff.partials(x)
169+
py = ForwardDiff.partials(y)
170+
p = Base.Cartesian.@ntuple $P p -> $ifelse(m, px[p], py[p])
171+
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
172+
end
173+
end
174+
@generated function ifelse(
175+
m::AbstractMask,
176+
x::Number,
177+
y::ForwardDiff.Dual{TAG,V,P}
178+
) where {TAG,V,P}
179+
quote
180+
$(Expr(:meta, :inline))
181+
z = $ifelse(m, x, ForwardDiff.value(y))
182+
py = ForwardDiff.partials(y)
183+
p = Base.Cartesian.@ntuple $P p -> $ifelse(m, zero($V), py[p])
184+
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
185+
end
186+
end
187+
@generated function ifelse(
188+
m::AbstractMask,
189+
x::ForwardDiff.Dual{TAG,V,P},
190+
y::Number
191+
) where {TAG,V,P}
192+
quote
193+
$(Expr(:meta, :inline))
194+
z = $ifelse(m, ForwardDiff.value(x), y)
195+
px = ForwardDiff.partials(x)
196+
p = Base.Cartesian.@ntuple $P p -> $ifelse(m, px[p], zero($V))
197+
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
198+
end
199+
end
200+
@inline function SLEEFPirates.softplus(x::ForwardDiff.Dual{TAG}) where {TAG}
201+
val = ForwardDiff.value(x)
202+
expx = exp(val)
203+
vx = log1p(expx)
204+
px = Base.FastMath.inv_fast(one(val) + Base.FastMath.inv_fast(expx))
205+
ForwardDiff.Dual{TAG}(vx, Base.FastMath.mul_fast(ForwardDiff.partials(x), px))
206+
end
22207

23208
@generated function init_dual(v::Tuple{Vararg{AbstractSIMD,A}}) where {A}
24209
res = Expr(:tuple)

src/predicates.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ isscopedname(:(Base.Checked.checked_add), (:Base, :Checked), :checked_add)
1111
function isscopedname(ex, modpath, name::Symbol)
1212
isexpr(ex, :(.), 2) &&
1313
(a = ex.args[2]; isa(a, QuoteNode) && a.value === name) &&
14-
hasscope(ex.args[1], modpath)
14+
hasscope(ex.args[1], modpath)
1515
end
1616
hasscope(modex, mod::Symbol) = modex === mod
1717
hasscope(modex, mod::Tuple{Symbol}) = hasscope(modex, mod[1])

0 commit comments

Comments
 (0)