
Some TF stdlib fixes #22112


Merged: 7 commits, Jan 25, 2019
2 changes: 1 addition & 1 deletion lib/SILGen/SILGenExpr.cpp
@@ -5390,7 +5390,7 @@ RValue RValueEmitter::visitAutoDiffFunctionExtractOriginalExpr(
auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
auto *orig = SGF.B.createAutoDiffFunctionExtractOriginal(
E, diffFunc.forward(SGF));
- return RValue(SGF, E, ManagedValue::forUnmanaged(orig));
+ return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(orig));
}

RValue RValueEmitter::visitTapExpr(TapExpr *E, SGFContext C) {
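Why this fix: `autodiff_function_extract [original]` consumes the forwarded `@autodiff` function value and produces an owned result, so wrapping it with `ManagedValue::forUnmanaged` left it without a cleanup. `emitManagedRValueWithCleanup` registers the destroy, which is why the SILGen test below now checks for a `destroy_value` after the borrowed apply, and returns `[[ORIG]]` directly instead of going through an extra `copy_value`.

A minimal sketch of source that exercises this path, modeled on the `apply()` test below; the `@autodiff` function-type spelling and the implicit conversion to a plain function type are assumptions about the Swift for TensorFlow branch of this era, not confirmed by this diff:

func square(_ x: Float) -> Float { return x * x }

func exercise() -> (Float) -> Float {
  let diffed: @autodiff (Float) -> Float = square  // hypothetical @autodiff value
  // Converting back to a plain function type lowers through
  // autodiff_function_extract [original]; after this fix the extracted
  // value is emitted at +1 with a cleanup rather than as an unmanaged value.
  let orig: (Float) -> Float = diffed
  _ = orig(3)   // borrow + apply, then destroy_value of the owned extract
  return orig   // the owned value is returned without a second copy
}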
8 changes: 4 additions & 4 deletions stdlib/public/TensorFlow/Gradients.swift
@@ -559,16 +559,16 @@ extension Tensor where Scalar : Differentiable & FloatingPoint {
}

@inlinable
- func _vjpMean(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
-   let value = mean(squeezingAxes: axes)
+ func _vjpMean(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+   let value = mean(alongAxes: axes)
return (value, { [shape = shapeTensor, count = scalarCountTensor] in
$0.broadcast(toShape: shape) / Tensor(count)
})
}

@inlinable
- func _vjpSum(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
-   let value = sum(squeezingAxes: axes)
+ func _vjpSum(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+   let value = sum(alongAxes: axes)
return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
}
}
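Both pullbacks work by broadcasting the incoming gradient back to the input's shape, which lines up only when the reduced dimensions are kept as size 1 (hence the rename from `squeezingAxes` to `alongAxes`); `_vjpMean` additionally divides the broadcast seed by the element count. A short sketch mirroring the runtime tests below, assuming the Swift for TensorFlow toolchain:

import TensorFlow

let input = Tensor<Float>(ones: [2, 2])

// Pullback of sum(alongAxes:): the seed is broadcast to the input shape.
let sumPB = pullback(at: input) { (a: Tensor<Float>) in a.sum(alongAxes: 0, 1) }
print(sumPB(Tensor(3)))   // 3.0 in all four positions

// Gradient of mean(alongAxes:): the broadcast seed divided by the count (4).
let meanGrad = gradient { (a: Tensor<Float>) in a.mean(alongAxes: 0, 1) }
print(meanGrad(input))    // 0.25 in all four positions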
52 changes: 26 additions & 26 deletions stdlib/public/TensorFlow/Ops.swift
@@ -1225,10 +1225,6 @@ public extension Tensor where Scalar : Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
@inlinable @inline(__always)
- @differentiable(
-   wrt: self, vjp: _vjpMean(squeezingAxes:)
-   where Scalar : Differentiable & FloatingPoint
- )
func mean(squeezingAxes axes: [Int32]) -> Tensor {
return Raw.mean(self, reductionIndices: Tensor<Int32>(axes),
keepDims: false)
@@ -1239,10 +1235,6 @@ public extension Tensor where Scalar : Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
@inlinable @inline(__always)
- @differentiable(
-   wrt: self, vjp: _vjpMean(squeezingAxes:)
-   where Scalar : Differentiable & FloatingPoint
- )
func mean(squeezingAxes axes: Int32...) -> Tensor {
return mean(squeezingAxes: axes)
}
@@ -1252,10 +1244,6 @@ public extension Tensor where Scalar : Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
@inlinable @inline(__always)
- @differentiable(
-   wrt: self, vjp: _vjpSum(squeezingAxes:)
-   where Scalar : Differentiable & FloatingPoint
- )
func sum(squeezingAxes axes: [Int32]) -> Tensor {
return Raw.sum(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
}
@@ -1265,10 +1253,6 @@ public extension Tensor where Scalar : Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
@inlinable @inline(__always)
- @differentiable(
-   wrt: self, vjp: _vjpSum(squeezingAxes:)
-   where Scalar : Differentiable & FloatingPoint
- )
func sum(squeezingAxes axes: Int32...) -> Tensor {
return sum(squeezingAxes: axes)
}
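The four `@differentiable` attributes removed above pointed the `squeezingAxes` reductions at VJPs whose pullback broadcasts the seed back to the input shape, and that broadcast aligns correctly only when the reduced dimensions are kept as size 1. Accordingly, the VJPs are renamed to `alongAxes` in Gradients.swift and reattached to the matching overloads below, leaving the `squeezingAxes` variants non-differentiable for now (their tests are commented out in tensor_autodiff_runtime.swift). A shape sketch of the mismatch, under the usual trailing-dimension broadcasting rules:

let x = Tensor<Float>(ones: [2, 3])
let squeezed = x.sum(squeezingAxes: 1)  // shape [2]: broadcasting [2] back to
                                        // [2, 3] aligns with the trailing axis
                                        // and does not replicate correctly
let kept = x.sum(alongAxes: 1)          // shape [2, 1]: broadcasting [2, 1] to
                                        // [2, 3] replicates along the reduced
                                        // axis, which is what the VJP needs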
@@ -1298,7 +1282,7 @@ public extension Tensor where Scalar : Numeric {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable @inline(__always)
@differentiable(
-   wrt: self, vjp: _vjpMean(squeezingAxes:)
+   wrt: self, vjp: _vjpMean(alongAxes:)
where Scalar : Differentiable & FloatingPoint
)
func mean(alongAxes axes: [Int32]) -> Tensor {
@@ -1310,10 +1294,7 @@ public extension Tensor where Scalar : Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable @inline(__always)
- @differentiable(
-   wrt: self, vjp: _vjpMean(squeezingAxes:)
-   where Scalar : Differentiable & FloatingPoint
- )
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func mean(alongAxes axes: Int32...) -> Tensor {
return mean(alongAxes: axes)
}
@@ -1324,7 +1305,7 @@ public extension Tensor where Scalar : Numeric {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable @inline(__always)
@differentiable(
-   wrt: self, vjp: _vjpSum(squeezingAxes:)
+   wrt: self, vjp: _vjpSum(alongAxes:)
where Scalar : Differentiable & FloatingPoint
)
func sum(alongAxes axes: [Int32]) -> Tensor {
@@ -1336,14 +1317,33 @@ public extension Tensor where Scalar : Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable @inline(__always)
- @differentiable(
-   wrt: self, vjp: _vjpSum(squeezingAxes:)
-   where Scalar : Differentiable & FloatingPoint
- )
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func sum(alongAxes axes: Int32...) -> Tensor {
return sum(alongAxes: axes)
}

+ /// Returns the variance along the specified axes. The reduced dimensions are
+ /// retained with value 1. Does not apply Bessel's correction.
+ /// - Parameter axes: The dimensions to reduce.
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+ @inlinable @inline(__always)
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+ func variance(alongAxes axes: Int32...) -> Tensor {
+   return variance(alongAxes: axes)
+ }
+
+ /// Returns the variance along the specified axes. The reduced dimensions are
+ /// retained with value 1. Does not apply Bessel's correction.
+ /// - Parameter axes: The dimensions to reduce.
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+ @inlinable @inline(__always)
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+ func variance(alongAxes axes: [Int32]) -> Tensor {
+   let mean = self.mean(alongAxes: axes)
+   let squaredDiff = (self - mean).squared()
+   return squaredDiff.mean(alongAxes: axes)
+ }

/// Returns the product along the specified axes. The reduced dimensions are
/// retained with value 1.
/// - Parameter axes: The dimensions to reduce.
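The new `variance(alongAxes:)` is defined compositionally as the mean of squared deviations from the mean: population variance, dividing by n rather than n - 1, hence "does not apply Bessel's correction". Because it is built from `mean(alongAxes:)` and `squared()`, it inherits its derivative from those operations and needs no hand-written VJP. A usage sketch; the `shape:scalars:` initializer is an assumption about this branch's API, not shown in the diff:

let t = Tensor<Float>(shape: [2, 2], scalars: [1, 2, 3, 4])
let v = t.variance(alongAxes: 0, 1)
// mean = 2.5; squared deviations = [2.25, 0.25, 0.25, 2.25]
// v has shape [1, 1] and value 1.25 (sum 5.0 divided by n = 4, not n - 1)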
4 changes: 2 additions & 2 deletions test/AutoDiff/autodiff_function_silgen.swift
@@ -33,10 +33,10 @@ func apply() {
// CHECK-SILGEN: [[ORIG:%.*]] = autodiff_function_extract [original] [[DIFFED_COPY]] : $@autodiff @callee_guaranteed (Float) -> Float
// CHECK-SILGEN: [[BORROWED_ORIG:%.*]] = begin_borrow [[ORIG]] : $@callee_guaranteed (Float) -> Float
// CHECK-SILGEN: apply [[BORROWED_ORIG]]({{%.*}}) : $@callee_guaranteed (Float) -> Float
+ // CHECK-SILGEN: destroy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float
// CHECK-SILGEN: [[DIFFED_COPY:%.*]] = copy_value [[DIFFED]] : $@autodiff @callee_guaranteed (Float) -> Float
// CHECK-SILGEN: [[ORIG:%.*]] = autodiff_function_extract [original] [[DIFFED_COPY]] : $@autodiff @callee_guaranteed (Float) -> Float
- // CHECK-SILGEN: [[ORIG_COPY:%.*]] = copy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float
- // CHECK-SILGEN: return [[ORIG_COPY]] : $@callee_guaranteed (Float) -> Float
+ // CHECK-SILGEN: return [[ORIG]] : $@callee_guaranteed (Float) -> Float

// CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}}
// CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float
10 changes: 5 additions & 5 deletions test/TensorFlowRuntime/tensor_autodiff_runtime.swift
@@ -75,27 +75,27 @@ TensorADTests.testAllBackends("negate") {
TensorADTests.testAllBackends("sum") {
let input = Tensor<Float>(randomNormal: [2, 2])
let sumPullbackScalar = pullback(at: input) { (a: Tensor<Float>) in a.sum() }
- let sumPullbackSqueezingAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(squeezingAxes: 0, 1) }
+ // let sumPullbackSqueezingAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(squeezingAxes: 0, 1) }
let sumPullbackAlongAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(alongAxes: 0, 1) }

let expected = Tensor<Float>(ones: [2, 2])
expectTrue(sumPullbackScalar(Tensor(1)) == expected)
- expectTrue(sumPullbackSqueezingAxes(Tensor(1)) == expected)
+ // expectTrue(sumPullbackSqueezingAxes(Tensor(1)) == expected)
expectTrue(sumPullbackAlongAxes(Tensor(1)) == expected)
expectTrue(sumPullbackScalar(Tensor(3)) == expected * 3)
- expectTrue(sumPullbackSqueezingAxes(Tensor(3)) == expected * 3)
+ // expectTrue(sumPullbackSqueezingAxes(Tensor(3)) == expected * 3)
expectTrue(sumPullbackAlongAxes(Tensor(3)) == expected * 3)
}

TensorADTests.testAllBackends("mean") {
let meanGradScalar = gradient { (a: Tensor<Float>) in a.mean() }
- let meanGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.mean(squeezingAxes: 0, 1) }
+ // let meanGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.mean(squeezingAxes: 0, 1) }
let meanGradAlongAxes = gradient { (a: Tensor<Float>) in a.mean(alongAxes: 0, 1) }

let input = Tensor<Float>(ones: [2, 2])
let expected = Tensor<Float>(shape: [2, 2], repeating: 0.25)
expectTrue(meanGradScalar(input) == expected)
- expectTrue(meanGradSqueezingAxes(input) == expected)
+ // expectTrue(meanGradSqueezingAxes(input) == expected)
expectTrue(meanGradAlongAxes(input) == expected)
}

2 changes: 1 addition & 1 deletion utils/update_checkout/update-checkout-config.json
@@ -242,7 +242,7 @@
"icu": "release-61-1",
"tensorflow": "a6924e6affd935f537cdaf8977094df0e15a7957",
"tensorflow-swift-bindings": "10e591340134c37a6c3a1df735a7334a77d5cbc7",
"tensorflow-swift-apis": "731ce402ce0bd7459b898b112deb89379bbda893"
"tensorflow-swift-apis": "ece4a67ed844919b04d2aec9d131b098aee39413"
}
}
}