Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 7a104ca

Browse files
dan12411dan-zheng
authored andcommitted
Added meanAbsoluteError function (#182)
1 parent dbeeb06 commit 7a104ca

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ public func meanSquaredError<Scalar: TensorFlowFloatingPoint>(
2424
return (expected - predicted).squared().mean()
2525
}
2626

27+
/// Computes the mean absolute error between predictions and expectations.
28+
///
29+
/// - Parameters:
30+
/// - predicted: Predicted outputs from a neural network.
31+
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
32+
@differentiable(wrt: predicted)
33+
public func meanAbsoluteError<Scalar: TensorFlowFloatingPoint>(
34+
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
35+
) -> Tensor<Scalar> {
36+
return abs(expected - predicted).mean()
37+
}
38+
2739
/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
2840
///
2941
/// - Parameters:

Tests/TensorFlowTests/LossTests.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ final class LossTests: XCTestCase {
2727
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
2828
}
2929

30+
func testMeanAbsoluteError() {
31+
let predicted = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
32+
let expected = Tensor<Float>(
33+
shape: [2, 4],
34+
scalars: [0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1])
35+
36+
let loss = meanAbsoluteError(predicted: predicted, expected: expected)
37+
let expectedLoss: Float = 4.25
38+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
39+
}
40+
3041
func testMeanSquaredErrorGrad() {
3142
let predicted = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
3243
let expected = Tensor<Float>(

0 commit comments

Comments
 (0)