
Commit 9ffbeee

jekbradbury authored and rxwei committed
Some TF stdlib fixes (#22112)

* [TF API] Fix VJPs for sum and mean. The VJP registered for the squeezingAxes variant of each of those methods is in fact the correct VJP for the alongAxes variant; the VJP for the squeezingAxes variant requires inserting an additional dimension, which is left for future implementation.
* [TF API] Add Tensor.variance.
* [TF API] Mark some more reduction methods `@differentiable where`.
* [TF API] Remove tests for broken sum/mean VJPs.
* [TF] Check out a branch of swift-apis.
* [AutoDiff] Fix SILGen for `AutoDiffFunctionExtractExpr` so that it won't leak owned values.
1 parent 11b384d commit 9ffbeee
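
Why the VJP swap is correct, in one example. The sketch below is not part of the commit and assumes a Swift for TensorFlow toolchain that includes these changes: `sum(alongAxes:)` keeps every reduced dimension with size 1, so its pullback only needs to broadcast the incoming cotangent back to the input shape, which is exactly what the previously misregistered VJP does; `sum(squeezingAxes:)` drops those dimensions, so its pullback would first have to re-insert them.

import TensorFlow

// Minimal sketch, not from the commit.
let x = Tensor<Float>(randomNormal: [2, 3])
// `sum(alongAxes: 1)` has shape [2, 1]: the reduced axis is kept with size 1.
let sumPB = pullback(at: x) { (a: Tensor<Float>) in a.sum(alongAxes: 1) }
// The pullback broadcasts the seed back to the input shape: all ones, shape [2, 3].
print(sumPB(Tensor(1)))
// `sum(squeezingAxes: 1)` would produce shape [2] instead; its cotangent needs the
// squeezed dimension re-inserted before broadcasting, which this commit leaves
// unimplemented, so that variant loses its `@differentiable` attribute for now.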

6 files changed: +39 -39 lines changed


lib/SILGen/SILGenExpr.cpp

Lines changed: 1 addition & 1 deletion
@@ -5390,7 +5390,7 @@ RValue RValueEmitter::visitAutoDiffFunctionExtractOriginalExpr(
   auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
   auto *orig = SGF.B.createAutoDiffFunctionExtractOriginal(
       E, diffFunc.forward(SGF));
-  return RValue(SGF, E, ManagedValue::forUnmanaged(orig));
+  return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(orig));
 }
 
 RValue RValueEmitter::visitTapExpr(TapExpr *E, SGFContext C) {

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 4 additions & 4 deletions
@@ -559,16 +559,16 @@ extension Tensor where Scalar : Differentiable & FloatingPoint {
   }
 
   @inlinable
-  func _vjpMean(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
-    let value = mean(squeezingAxes: axes)
+  func _vjpMean(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+    let value = mean(alongAxes: axes)
     return (value, { [shape = shapeTensor, count = scalarCountTensor] in
       $0.broadcast(toShape: shape) / Tensor(count)
     })
   }
 
   @inlinable
-  func _vjpSum(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
-    let value = sum(squeezingAxes: axes)
+  func _vjpSum(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+    let value = sum(alongAxes: axes)
     return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
   }
 }
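
For contrast, here is a purely hypothetical sketch of what the still-missing squeezingAxes VJP would have to do: re-insert the squeezed dimensions before broadcasting. The method name is made up for illustration, and `reshaped(toShape:)` is assumed to be available; none of this is in the commit.

extension Tensor where Scalar : Differentiable & FloatingPoint {
  // Hypothetical only. Computes the primal with squeezed axes, then reshapes the
  // incoming cotangent to the "kept" shape before broadcasting to `self`'s shape.
  func _vjpSumSqueezingSketch(_ axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
    let value = sum(squeezingAxes: axes)
    let keptShape = sum(alongAxes: axes).shapeTensor   // reduced axes kept as size 1
    return (value, { [shape = shapeTensor] seed in
      seed.reshaped(toShape: keptShape).broadcast(toShape: shape)
    })
  }
}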

stdlib/public/TensorFlow/Ops.swift

Lines changed: 26 additions & 26 deletions
@@ -1225,10 +1225,6 @@ public extension Tensor where Scalar : Numeric {
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
   @inlinable @inline(__always)
-  @differentiable(
-    wrt: self, vjp: _vjpMean(squeezingAxes:)
-    where Scalar : Differentiable & FloatingPoint
-  )
   func mean(squeezingAxes axes: [Int32]) -> Tensor {
     return Raw.mean(self, reductionIndices: Tensor<Int32>(axes),
                     keepDims: false)
@@ -1239,10 +1235,6 @@ public extension Tensor where Scalar : Numeric {
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
   @inlinable @inline(__always)
-  @differentiable(
-    wrt: self, vjp: _vjpMean(squeezingAxes:)
-    where Scalar : Differentiable & FloatingPoint
-  )
   func mean(squeezingAxes axes: Int32...) -> Tensor {
     return mean(squeezingAxes: axes)
   }
@@ -1252,10 +1244,6 @@ public extension Tensor where Scalar : Numeric {
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
   @inlinable @inline(__always)
-  @differentiable(
-    wrt: self, vjp: _vjpSum(squeezingAxes:)
-    where Scalar : Differentiable & FloatingPoint
-  )
   func sum(squeezingAxes axes: [Int32]) -> Tensor {
     return Raw.sum(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
   }
@@ -1265,10 +1253,6 @@ public extension Tensor where Scalar : Numeric {
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
   @inlinable @inline(__always)
-  @differentiable(
-    wrt: self, vjp: _vjpSum(squeezingAxes:)
-    where Scalar : Differentiable & FloatingPoint
-  )
   func sum(squeezingAxes axes: Int32...) -> Tensor {
     return sum(squeezingAxes: axes)
   }
@@ -1298,7 +1282,7 @@ public extension Tensor where Scalar : Numeric {
   /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
   @inlinable @inline(__always)
   @differentiable(
-    wrt: self, vjp: _vjpMean(squeezingAxes:)
+    wrt: self, vjp: _vjpMean(alongAxes:)
     where Scalar : Differentiable & FloatingPoint
   )
   func mean(alongAxes axes: [Int32]) -> Tensor {
@@ -1310,10 +1294,7 @@ public extension Tensor where Scalar : Numeric {
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
   @inlinable @inline(__always)
-  @differentiable(
-    wrt: self, vjp: _vjpMean(squeezingAxes:)
-    where Scalar : Differentiable & FloatingPoint
-  )
+  @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
   func mean(alongAxes axes: Int32...) -> Tensor {
     return mean(alongAxes: axes)
   }
@@ -1324,7 +1305,7 @@ public extension Tensor where Scalar : Numeric {
   /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
   @inlinable @inline(__always)
   @differentiable(
-    wrt: self, vjp: _vjpSum(squeezingAxes:)
+    wrt: self, vjp: _vjpSum(alongAxes:)
     where Scalar : Differentiable & FloatingPoint
   )
   func sum(alongAxes axes: [Int32]) -> Tensor {
@@ -1336,14 +1317,33 @@ public extension Tensor where Scalar : Numeric {
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
   @inlinable @inline(__always)
-  @differentiable(
-    wrt: self, vjp: _vjpSum(squeezingAxes:)
-    where Scalar : Differentiable & FloatingPoint
-  )
+  @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
   func sum(alongAxes axes: Int32...) -> Tensor {
     return sum(alongAxes: axes)
   }
 
+  /// Returns the variance along the specified axes. The reduced dimensions are
+  /// retained with value 1. Does not apply Bessel's correction.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  @inlinable @inline(__always)
+  @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+  func variance(alongAxes axes: Int32...) -> Tensor {
+    return variance(alongAxes: axes)
+  }
+
+  /// Returns the variance along the specified axes. The reduced dimensions are
+  /// retained with value 1. Does not apply Bessel's correction.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  @inlinable @inline(__always)
+  @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+  func variance(alongAxes axes: [Int32]) -> Tensor {
+    let mean = self.mean(alongAxes: axes)
+    let squaredDiff = (self - mean).squared()
+    return squaredDiff.mean(alongAxes: axes)
+  }
+
   /// Returns the product along the specified axes. The reduced dimensions are
   /// retained with value 1.
   /// - Parameter axes: The dimensions to reduce.
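
A short usage sketch for the new variance API (not from the commit; the values follow directly from the population-variance definition implemented above, and `init(shape:scalars:)` is the standard element initializer):

import TensorFlow

let x = Tensor<Float>(shape: [2, 2], scalars: [1, 2, 3, 4])
// Variance over all elements: mean = 2.5, squared deviations = [2.25, 0.25, 0.25, 2.25],
// whose mean is 1.25. Reduced axes are kept, so the result has shape [1, 1].
print(x.variance(alongAxes: 0, 1))   // [[1.25]]
// Per-row variance: each row has deviations of +/-0.5, so the variance is 0.25.
print(x.variance(alongAxes: 1))      // [[0.25], [0.25]], shape [2, 1]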

test/AutoDiff/autodiff_function_silgen.swift

Lines changed: 2 additions & 2 deletions
@@ -33,10 +33,10 @@ func apply() {
 // CHECK-SILGEN: [[ORIG:%.*]] = autodiff_function_extract [original] [[DIFFED_COPY]] : $@autodiff @callee_guaranteed (Float) -> Float
 // CHECK-SILGEN: [[BORROWED_ORIG:%.*]] = begin_borrow [[ORIG]] : $@callee_guaranteed (Float) -> Float
 // CHECK-SILGEN: apply [[BORROWED_ORIG]]({{%.*}}) : $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: destroy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float
 // CHECK-SILGEN: [[DIFFED_COPY:%.*]] = copy_value [[DIFFED]] : $@autodiff @callee_guaranteed (Float) -> Float
 // CHECK-SILGEN: [[ORIG:%.*]] = autodiff_function_extract [original] [[DIFFED_COPY]] : $@autodiff @callee_guaranteed (Float) -> Float
-// CHECK-SILGEN: [[ORIG_COPY:%.*]] = copy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float
-// CHECK-SILGEN: return [[ORIG_COPY]] : $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: return [[ORIG]] : $@callee_guaranteed (Float) -> Float
 
 // CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}}
 // CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float

test/TensorFlowRuntime/tensor_autodiff_runtime.swift

Lines changed: 5 additions & 5 deletions
@@ -75,27 +75,27 @@ TensorADTests.testAllBackends("negate") {
 TensorADTests.testAllBackends("sum") {
   let input = Tensor<Float>(randomNormal: [2, 2])
   let sumPullbackScalar = pullback(at: input) { (a: Tensor<Float>) in a.sum() }
-  let sumPullbackSqueezingAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(squeezingAxes: 0, 1) }
+  // let sumPullbackSqueezingAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(squeezingAxes: 0, 1) }
   let sumPullbackAlongAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(alongAxes: 0, 1) }
 
   let expected = Tensor<Float>(ones: [2, 2])
   expectTrue(sumPullbackScalar(Tensor(1)) == expected)
-  expectTrue(sumPullbackSqueezingAxes(Tensor(1)) == expected)
+  // expectTrue(sumPullbackSqueezingAxes(Tensor(1)) == expected)
   expectTrue(sumPullbackAlongAxes(Tensor(1)) == expected)
   expectTrue(sumPullbackScalar(Tensor(3)) == expected * 3)
-  expectTrue(sumPullbackSqueezingAxes(Tensor(3)) == expected * 3)
+  // expectTrue(sumPullbackSqueezingAxes(Tensor(3)) == expected * 3)
   expectTrue(sumPullbackAlongAxes(Tensor(3)) == expected * 3)
 }
 
 TensorADTests.testAllBackends("mean") {
   let meanGradScalar = gradient { (a: Tensor<Float>) in a.mean() }
-  let meanGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.mean(squeezingAxes: 0, 1) }
+  // let meanGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.mean(squeezingAxes: 0, 1) }
   let meanGradAlongAxes = gradient { (a: Tensor<Float>) in a.mean(alongAxes: 0, 1) }
 
   let input = Tensor<Float>(ones: [2, 2])
   let expected = Tensor<Float>(shape: [2, 2], repeating: 0.25)
   expectTrue(meanGradScalar(input) == expected)
-  expectTrue(meanGradSqueezingAxes(input) == expected)
+  // expectTrue(meanGradSqueezingAxes(input) == expected)
   expectTrue(meanGradAlongAxes(input) == expected)
 }
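
Because `variance(alongAxes:)` is declared `@differentiable(wrt: self ...)` and is composed entirely of operations that already have VJPs, it should differentiate without a hand-written VJP. A hedged sketch in the style of the tests above; it is not one of the tests added by this commit:

import TensorFlow

let varianceGrad = gradient { (a: Tensor<Float>) in a.variance(alongAxes: 0, 1) }
let x = Tensor<Float>(shape: [2, 2], scalars: [1, 2, 3, 4])
// Mathematically, d(variance)/dx = 2 * (x - mean) / count, which for this input
// is [[-0.75, -0.25], [0.25, 0.75]].
print(varianceGrad(x))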

utils/update_checkout/update-checkout-config.json

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@
       "icu": "release-61-1",
       "tensorflow": "a6924e6affd935f537cdaf8977094df0e15a7957",
       "tensorflow-swift-bindings": "10e591340134c37a6c3a1df735a7334a77d5cbc7",
-      "tensorflow-swift-apis": "731ce402ce0bd7459b898b112deb89379bbda893"
+      "tensorflow-swift-apis": "ece4a67ed844919b04d2aec9d131b098aee39413"
     }
   }
 }
