1
1
module ForwardDiffExt
2
2
import ForwardDiff, ChainRulesCore
3
- using SIMDDualNumbers, LoopVectorization
3
+ using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff
4
+
5
+ import IfElse: ifelse
6
+ using VectorizationBase: AbstractSIMD, AbstractMask, zero_offsets
7
+
4
8
using LoopVectorization:
5
9
AbstractSIMD,
6
10
AbstractStridedPointer,
@@ -18,7 +22,188 @@ using LoopVectorization:
18
22
mask,
19
23
vfnmadd_fast,
20
24
mul_fast
21
- using VectorizationBase: zero_offsets
25
+
26
+ @generated function Base. abs (
27
+ x:: ForwardDiff.Dual{TAG,S,N}
28
+ ) where {TAG,S<: AbstractSIMD ,N}
29
+ quote
30
+ $ (Expr (:meta , :inline ))
31
+ val = x. value
32
+ p = x. partials
33
+ cmp = val < zero ($ S)
34
+ absx = $ ifelse (cmp, - val, val)
35
+ Base. Cartesian. @nexprs $ N n -> p_n = p[n]
36
+ ForwardDiff. Dual {$TAG} (
37
+ absx,
38
+ ForwardDiff. Partials (
39
+ Base. Cartesian. @ntuple $ N n -> $ ifelse (cmp, - p_n, p_n)
40
+ )
41
+ )
42
+ end
43
+ end
44
+ @inline function Base. max (
45
+ x:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N} ,
46
+ y:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N}
47
+ ) where {TAG,N}
48
+ vx = ForwardDiff. value (x)
49
+ vy = ForwardDiff. value (y)
50
+ xgy = vx > vy
51
+ z = ifelse (xgy, vx, vy)
52
+ p = VectorizationBase. fmap (
53
+ ifelse,
54
+ xgy,
55
+ ForwardDiff. partials (x). values,
56
+ ForwardDiff. partials (y). values
57
+ )
58
+ ForwardDiff. Dual {TAG} (z, ForwardDiff. Partials (p))
59
+ end
60
+
61
+ @inline Base. max (
62
+ x:: T ,
63
+ y:: Real
64
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
65
+ @inline Base. max (
66
+ y:: Real ,
67
+ x:: T
68
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
69
+ @inline Base. max (
70
+ x:: T ,
71
+ y:: Int
72
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
73
+ @inline Base. max (
74
+ y:: Int ,
75
+ x:: T
76
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
77
+
78
+ @inline function Base. min (
79
+ x:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N} ,
80
+ y:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N}
81
+ ) where {TAG,N}
82
+ vx = ForwardDiff. value (x)
83
+ vy = ForwardDiff. value (y)
84
+ xgy = vx < vy
85
+ z = ifelse (xgy, vx, vy)
86
+ p = VectorizationBase. fmap (
87
+ ifelse,
88
+ xgy,
89
+ ForwardDiff. partials (x). values,
90
+ ForwardDiff. partials (y). values
91
+ )
92
+ ForwardDiff. Dual {TAG} (z, ForwardDiff. Partials (p))
93
+ end
94
+ @inline Base. min (
95
+ x:: T ,
96
+ y:: Real
97
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
98
+ @inline Base. min (
99
+ y:: Real ,
100
+ x:: T
101
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
102
+ @inline Base. min (
103
+ x:: T ,
104
+ y:: Int
105
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
106
+ @inline Base. min (
107
+ y:: Int ,
108
+ x:: T
109
+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
110
+
111
+ @generated function SLEEFPirates. tanh_fast (
112
+ x:: ForwardDiff.Dual{T,S,N}
113
+ ) where {T,S,N}
114
+ quote
115
+ $ (Expr (:meta , :inline ))
116
+ t = tanh_fast (x. value)
117
+ ∂t = $ (VectorizationBase. vfnmadd_fast)(t, t, one (S))
118
+ p = x. partials
119
+ ForwardDiff. Dual {T} (
120
+ t,
121
+ ForwardDiff. Partials (
122
+ Base. Cartesian. @ntuple $ N n -> $ (Base. FastMath. mul_fast)(∂t, p[n])
123
+ )
124
+ )
125
+ end
126
+ end
127
+ @generated function SLEEFPirates. sigmoid_fast (
128
+ x:: ForwardDiff.Dual{T,S,N}
129
+ ) where {T,S,N}
130
+ quote
131
+ $ (Expr (:meta , :inline ))
132
+ s = sigmoid_fast (x. value)
133
+ ∂s = $ (VectorizationBase. vfnmadd_fast)(s, s, s)
134
+ p = x. partials
135
+ ForwardDiff. Dual {T} (
136
+ s,
137
+ ForwardDiff. Partials (
138
+ Base. Cartesian. @ntuple $ N n -> $ (Base. FastMath. mul_fast)(∂s, p[n])
139
+ )
140
+ )
141
+ end
142
+ end
143
+ @generated function VectorizationBase. relu (
144
+ x:: ForwardDiff.Dual{T,S,N}
145
+ ) where {T,S,N}
146
+ quote
147
+ $ (Expr (:meta , :inline ))
148
+ v = x. value
149
+ z = zero (v)
150
+ cmp = v < z
151
+ r = ifelse (cmp, z, v)
152
+ p = x. partials
153
+ ForwardDiff. Dual {T} (
154
+ r,
155
+ ForwardDiff. Partials (Base. Cartesian. @ntuple $ N n -> ifelse (cmp, z, p[n]))
156
+ )
157
+ end
158
+ end
159
+
160
+ @generated function ifelse (
161
+ m:: AbstractMask ,
162
+ x:: ForwardDiff.Dual{TAG,V,P} ,
163
+ y:: ForwardDiff.Dual{TAG,V,P}
164
+ ) where {TAG,V,P}
165
+ quote
166
+ $ (Expr (:meta , :inline ))
167
+ z = $ ifelse (m, ForwardDiff. value (x), ForwardDiff. value (y))
168
+ px = ForwardDiff. partials (x)
169
+ py = ForwardDiff. partials (y)
170
+ p = Base. Cartesian. @ntuple $ P p -> $ ifelse (m, px[p], py[p])
171
+ ForwardDiff. Dual {$TAG} (z, ForwardDiff. Partials (p))
172
+ end
173
+ end
174
+ @generated function ifelse (
175
+ m:: AbstractMask ,
176
+ x:: Number ,
177
+ y:: ForwardDiff.Dual{TAG,V,P}
178
+ ) where {TAG,V,P}
179
+ quote
180
+ $ (Expr (:meta , :inline ))
181
+ z = $ ifelse (m, x, ForwardDiff. value (y))
182
+ py = ForwardDiff. partials (y)
183
+ p = Base. Cartesian. @ntuple $ P p -> $ ifelse (m, zero ($ V), py[p])
184
+ ForwardDiff. Dual {$TAG} (z, ForwardDiff. Partials (p))
185
+ end
186
+ end
187
+ @generated function ifelse (
188
+ m:: AbstractMask ,
189
+ x:: ForwardDiff.Dual{TAG,V,P} ,
190
+ y:: Number
191
+ ) where {TAG,V,P}
192
+ quote
193
+ $ (Expr (:meta , :inline ))
194
+ z = $ ifelse (m, ForwardDiff. value (x), y)
195
+ px = ForwardDiff. partials (x)
196
+ p = Base. Cartesian. @ntuple $ P p -> $ ifelse (m, px[p], zero ($ V))
197
+ ForwardDiff. Dual {$TAG} (z, ForwardDiff. Partials (p))
198
+ end
199
+ end
200
+ @inline function SLEEFPirates. softplus (x:: ForwardDiff.Dual{TAG} ) where {TAG}
201
+ val = ForwardDiff. value (x)
202
+ expx = exp (val)
203
+ vx = log1p (expx)
204
+ px = Base. FastMath. inv_fast (one (val) + Base. FastMath. inv_fast (expx))
205
+ ForwardDiff. Dual {TAG} (vx, Base. FastMath. mul_fast (ForwardDiff. partials (x), px))
206
+ end
22
207
23
208
@generated function init_dual (v:: Tuple{Vararg{AbstractSIMD,A}} ) where {A}
24
209
res = Expr (:tuple )
0 commit comments