@@ -52,7 +52,7 @@ public func meanAbsoluteError<Scalar: TensorFlowFloatingPoint>(
   predicted: Tensor<Scalar>,
   expected: Tensor<Scalar>
 ) -> Tensor<Scalar> {
-  l1Loss(predicted: predicted, expected: expected, reduction: { $0.mean() })
+  l1Loss(predicted: predicted, expected: expected, reduction: _mean)
 }

 /// Returns the mean squared error between predictions and expectations.
@@ -65,7 +65,7 @@ public func meanSquaredError<Scalar: TensorFlowFloatingPoint>(
   predicted: Tensor<Scalar>,
   expected: Tensor<Scalar>
 ) -> Tensor<Scalar> {
-  l2Loss(predicted: predicted, expected: expected, reduction: { $0.mean() })
+  l2Loss(predicted: predicted, expected: expected, reduction: _mean)
 }

 /// Returns the mean squared logarithmic error between predictions and expectations.
@@ -83,7 +83,7 @@ public func meanSquaredLogarithmicError<Scalar: TensorFlowFloatingPoint>(
 ) -> Tensor<Scalar> {
   let logPredicted = log(max(predicted, Tensor(0)) + 1)
   let logExpected = log(max(expected, Tensor(0)) + 1)
-  return l2Loss(predicted: logPredicted, expected: logExpected, reduction: { $0.mean() })
+  return l2Loss(predicted: logPredicted, expected: logExpected, reduction: _mean)
 }

 /// Returns the mean absolute percentage error between predictions and expectations.
@@ -109,7 +109,7 @@ public func meanAbsolutePercentageError<Scalar: TensorFlowFloatingPoint>(
 public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
   predicted: Tensor<Scalar>,
   expected: Tensor<Scalar>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   reduction(max(Tensor(0), Tensor(1) - expected * predicted))
 }
@@ -124,7 +124,7 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
 public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
   predicted: Tensor<Scalar>,
   expected: Tensor<Scalar>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   reduction(hingeLoss(predicted: predicted, expected: expected).squared())
 }
@@ -139,7 +139,7 @@ public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
 public func categoricalHingeLoss<Scalar: TensorFlowFloatingPoint>(
   predicted: Tensor<Scalar>,
   expected: Tensor<Scalar>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   let positive = (expected * predicted).sum(alongAxes: -1)
   let negative = ((Tensor(1) - expected) * predicted).max(alongAxes: -1)
@@ -157,7 +157,7 @@ public func categoricalHingeLoss<Scalar: TensorFlowFloatingPoint>(
 public func logCoshLoss<Scalar: TensorFlowFloatingPoint>(
   predicted: Tensor<Scalar>,
   expected: Tensor<Scalar>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   let x = predicted - expected
   return reduction(x + softplus(Tensor(-2) * x) - log(Tensor(2)))
@@ -173,7 +173,7 @@ public func logCoshLoss<Scalar: TensorFlowFloatingPoint>(
 public func poissonLoss<Scalar: TensorFlowFloatingPoint>(
   predicted: Tensor<Scalar>,
   expected: Tensor<Scalar>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   reduction(predicted - expected * log(predicted))
 }
@@ -194,6 +194,15 @@ public func kullbackLeiblerDivergence<Scalar: TensorFlowFloatingPoint>(
   reduction(expected * log(expected / predicted))
 }

+/// Workaround for cross-module default parameter @differentiable functions.
+/// Tensor<Scalar>.mean() is the preferred way to do this.
+@differentiable
+public func _mean<Scalar: TensorFlowFloatingPoint>(
+  _ value: Tensor<Scalar>
+) -> Tensor<Scalar> {
+  return value.mean()
+}
+
 /// Returns the softmax cross entropy (categorical cross entropy) between logits and labels.
 ///
 /// - Parameters:
@@ -204,7 +213,7 @@ public func kullbackLeiblerDivergence<Scalar: TensorFlowFloatingPoint>(
 public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
   logits: Tensor<Scalar>,
   labels: Tensor<Int32>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   reduction(softmaxCrossEntropyHelper(logits: logits, labels: labels))
 }
@@ -238,7 +247,7 @@ func _vjpSoftmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
 public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
   logits: Tensor<Scalar>,
   probabilities: Tensor<Scalar>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   reduction(softmaxCrossEntropyHelper(logits: logits, probabilities: probabilities))
 }
@@ -274,7 +283,7 @@ func _vjpSoftmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
 public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
   logits: Tensor<Scalar>,
   labels: Tensor<Scalar>,
-  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = { $0.mean() }
+  reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
 ) -> Tensor<Scalar> {
   // This numerically stable implementation is based on the TensorFlow Python API.
   let maxLogitsWithZero = max(logits, Tensor(0))
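For context, a minimal usage sketch of what this change enables (not part of the diff above; the tensor values and the summing closure are illustrative assumptions): with _mean as the default, code in another module can call these losses without passing a reduction, and can still override the default with a custom @differentiable closure at the call site.

import TensorFlow

let predicted = Tensor<Float>([0.8, -0.3, 1.2, 0.5])
let expected = Tensor<Float>([1, -1, 1, -1])

// Uses the default reduction `_mean`, behaviorally the same as { $0.mean() }.
let meanLoss = hingeLoss(predicted: predicted, expected: expected)

// An explicit reduction closure still works when supplied at the call site.
let summedLoss = hingeLoss(predicted: predicted, expected: expected, reduction: { $0.sum() })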