17
17
function ChainRulesCore. rrule (dist:: Delta , x:: AbstractVector , y:: AbstractVector )
18
18
d = dist (x, y)
19
19
function evaluate_pullback (:: Any )
20
- return NO_FIELDS, Zero (), Zero ()
20
+ return NoTangent (), ZeroTangent (), ZeroTangent ()
21
21
end
22
22
return d, evaluate_pullback
23
23
end
@@ -27,7 +27,7 @@ function ChainRulesCore.rrule(
27
27
)
28
28
P = Distances. pairwise (d, X, Y; dims= dims)
29
29
function pairwise_pullback (:: AbstractMatrix )
30
- return NO_FIELDS, NO_FIELDS, Zero (), Zero ()
30
+ return NoTangent (), NoTangent (), ZeroTangent (), ZeroTangent ()
31
31
end
32
32
return P, pairwise_pullback
33
33
end
@@ -37,7 +37,7 @@ function ChainRulesCore.rrule(
37
37
)
38
38
P = Distances. pairwise (d, X; dims= dims)
39
39
function pairwise_pullback (:: AbstractMatrix )
40
- return NO_FIELDS, NO_FIELDS, Zero ()
40
+ return NoTangent (), NoTangent (), ZeroTangent ()
41
41
end
42
42
return P, pairwise_pullback
43
43
end
@@ -47,7 +47,7 @@ function ChainRulesCore.rrule(
47
47
)
48
48
C = Distances. colwise (d, X, Y)
49
49
function colwise_pullback (:: AbstractVector )
50
- return NO_FIELDS, NO_FIELDS, Zero (), Zero ()
50
+ return NoTangent (), NoTangent (), ZeroTangent (), ZeroTangent ()
51
51
end
52
52
return C, colwise_pullback
53
53
end
57
57
function ChainRulesCore. rrule (dist:: DotProduct , x:: AbstractVector , y:: AbstractVector )
58
58
d = dist (x, y)
59
59
function evaluate_pullback (Δ:: Any )
60
- return NO_FIELDS , Δ .* y, Δ .* x
60
+ return NoTangent () , Δ .* y, Δ .* x
61
61
end
62
62
return d, evaluate_pullback
63
63
end
@@ -72,9 +72,9 @@ function ChainRulesCore.rrule(
72
72
P = Distances. pairwise (d, X, Y; dims= dims)
73
73
function pairwise_pullback_cols (Δ:: AbstractMatrix )
74
74
if dims == 1
75
- return NO_FIELDS, NO_FIELDS , Δ * Y, Δ' * X
75
+ return NoTangent (), NoTangent () , Δ * Y, Δ' * X
76
76
else
77
- return NO_FIELDS, NO_FIELDS , Y * Δ' , X * Δ
77
+ return NoTangent (), NoTangent () , Y * Δ' , X * Δ
78
78
end
79
79
end
80
80
return P, pairwise_pullback_cols
@@ -86,9 +86,9 @@ function ChainRulesCore.rrule(
86
86
P = Distances. pairwise (d, X; dims= dims)
87
87
function pairwise_pullback_cols (Δ:: AbstractMatrix )
88
88
if dims == 1
89
- return NO_FIELDS, NO_FIELDS , 2 * Δ * X
89
+ return NoTangent (), NoTangent () , 2 * Δ * X
90
90
else
91
- return NO_FIELDS, NO_FIELDS , 2 * X * Δ
91
+ return NoTangent (), NoTangent () , 2 * X * Δ
92
92
end
93
93
end
94
94
return P, pairwise_pullback_cols
@@ -99,7 +99,7 @@ function ChainRulesCore.rrule(
99
99
)
100
100
C = Distances. colwise (d, X, Y)
101
101
function colwise_pullback (Δ:: AbstractVector )
102
- return NO_FIELDS, NO_FIELDS , Δ' .* Y, Δ' .* X
102
+ return NoTangent (), NoTangent () , Δ' .* Y, Δ' .* X
103
103
end
104
104
return C, colwise_pullback
105
105
end
@@ -135,15 +135,15 @@ function ChainRulesCore.rrule(
135
135
∂b = InplaceableThunk (
136
136
@thunk ((- 2 * Δ) * dist. qmat * a_b), X̄ -> mul! (X̄, dist. qmat, a_b, true , - 2 * Δ)
137
137
)
138
- return Composite {typeof(dist)} (; qmat= ∂qmat), ∂a, ∂b
138
+ return Tangent {typeof(dist)} (; qmat= ∂qmat), ∂a, ∂b
139
139
end
140
140
return d, SqMahalanobis_pullback
141
141
end
142
142
143
143
# # Reverse Rules for matrix wrappers
144
144
145
145
function ChainRulesCore. rrule (:: Type{<:ColVecs} , X:: AbstractMatrix )
146
- ColVecs_pullback (Δ:: Composite ) = (NO_FIELDS , Δ. X)
146
+ ColVecs_pullback (Δ:: Tangent ) = (NoTangent () , Δ. X)
147
147
function ColVecs_pullback (:: AbstractVector{<:AbstractVector{<:Real}} )
148
148
return error (
149
149
" Pullback on AbstractVector{<:AbstractVector}.\n " *
@@ -155,7 +155,7 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
155
155
end
156
156
157
157
function ChainRulesCore. rrule (:: Type{<:RowVecs} , X:: AbstractMatrix )
158
- RowVecs_pullback (Δ:: Composite ) = (NO_FIELDS , Δ. X)
158
+ RowVecs_pullback (Δ:: Tangent ) = (NoTangent () , Δ. X)
159
159
function RowVecs_pullback (:: AbstractVector{<:AbstractVector{<:Real}} )
160
160
return error (
161
161
" Pullback on AbstractVector{<:AbstractVector}.\n " *
0 commit comments