Skip to content

Commit 70d7cc7

Browse files
eaplataniosrxwei
authored andcommitted
[AutoDiff] Added support for 'withoutDerivative(at:)'. (#25691)
1 parent def2be9 commit 70d7cc7

File tree

5 files changed

+25
-21
lines changed

5 files changed

+25
-21
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,16 +413,16 @@ NOTE(autodiff_external_nondifferentiable_function,none,
413413
"'@differentiable' and that are defined in other files", ())
414414
NOTE(autodiff_nondifferentiable_argument,none,
415415
"cannot differentiate through a non-differentiable argument; do you want "
416-
"to add '.withoutDerivative()'?", ())
416+
"to use 'withoutDerivative(at:)'?", ())
417417
NOTE(autodiff_nondifferentiable_result,none,
418418
"cannot differentiate through a non-differentiable result; do you want to "
419-
"add '.withoutDerivative()'?", ())
419+
"use 'withoutDerivative(at:)'?", ())
420420
NOTE(autodiff_noderivative_stored_property,none,
421421
"cannot differentiate through a '@noDerivative' stored property; do you "
422-
"want to add '.withoutDerivative()'?", ())
422+
"want to use 'withoutDerivative(at:)'?", ())
423423
WARNING(autodiff_nonvaried_result_fixit,none,
424424
"result does not depend on differentiation arguments and will always "
425-
"have a zero derivative; do you want to add '.withoutDerivative()'?",
425+
"have a zero derivative; do you want to use 'withoutDerivative(at:)'?",
426426
())
427427
NOTE(autodiff_enums_unsupported,none,
428428
"differentiating enum values is not yet supported", ())

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4270,11 +4270,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
42704270
// have a zero derivative.
42714271
if (!getActivityInfo().isVaried(origResult, getIndices().parameters)) {
42724272
// Emit fixit if original result has a valid source location.
4273-
auto sourceLoc = origResult.getLoc().getEndSourceLoc();
4274-
if (sourceLoc.isValid()) {
4273+
auto startLoc = origResult.getLoc().getStartSourceLoc();
4274+
auto endLoc = origResult.getLoc().getEndSourceLoc();
4275+
if (startLoc.isValid() && endLoc.isValid()) {
42754276
getContext()
4276-
.diagnose(sourceLoc, diag::autodiff_nonvaried_result_fixit)
4277-
.fixItInsertAfter(sourceLoc, ".withoutDerivative()");
4277+
.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
4278+
.fixItInsert(startLoc, "withoutDerivative(at:")
4279+
.fixItInsertAfter(endLoc, ")");
42784280
}
42794281
}
42804282
builder.setInsertionPoint(

stdlib/public/core/AutoDiff.swift

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,14 @@ public extension Differentiable where TangentVector == Self {
139139
}
140140
}
141141

142-
public extension Differentiable {
143-
/// Identity function that stops derivatives from propagating.
144-
@inlinable
145-
@inline(__always)
146-
@_semantics("autodiff.nonvarying")
147-
func withoutDerivative() -> Self { return self }
142+
/// Returns `x` like an identity function. When used in a context where `x` is
143+
/// being differentiated with respect to, this function will not produce any
144+
/// derivative at `x`.
145+
@inlinable
146+
@inline(__always)
147+
@_semantics("autodiff.nonvarying")
148+
public func withoutDerivative<T>(at x: T) -> T {
149+
x
148150
}
149151

150152
/// Applies the given closure `body` to `x`. When used in a context where `x` is

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ struct NoDerivativeProperty : Differentiable {
4545
// expected-error @+1 {{function is not differentiable}}
4646
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in
4747
var tmp = s
48-
// expected-note @+1 {{cannot differentiate through a '@noDerivative' stored property; do you want to add '.withoutDerivative()'?}}
48+
// expected-note @+1 {{cannot differentiate through a '@noDerivative' stored property; do you want to use 'withoutDerivative(at:)'?}}
4949
tmp.y = tmp.x
5050
return tmp.x
5151
}
5252
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in
53-
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{13-13=.withoutDerivative()}}
53+
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{13-13=)}}
5454
return s.y
5555
}
5656
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
57-
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{7-7=.withoutDerivative()}}
57+
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{3-3=withoutDerivative(at:}} {{7-7=)}}
5858
$0.y
5959
}
6060

@@ -74,7 +74,7 @@ _ = gradient(at: 0, in: uses_optionals)
7474

7575
func base(_ x: Float) -> Float {
7676
// expected-error @+2 {{expression is not differentiable}}
77-
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to add '.withoutDerivative()'?}}
77+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
7878
return Float(Int(x))
7979
}
8080

@@ -225,7 +225,7 @@ let no_return: @differentiable (Float) -> Float = { x in
225225
@differentiable
226226
// expected-note @+1 {{when differentiating this function definition}}
227227
func roundingGivesError(x: Float) -> Float {
228-
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to add '.withoutDerivative()'?}}
228+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
229229
return Float(Int(x))
230230
}
231231

@@ -261,7 +261,7 @@ func one() -> Float {
261261
}
262262
@differentiable
263263
func nonVariedResult(_ x: Float) -> Float {
264-
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{15-15=.withoutDerivative()}}
264+
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{15-15=)}}
265265
return one()
266266
}
267267

utils/update_checkout/update-checkout-config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@
401401
"clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a",
402402
"libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a",
403403
"tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86",
404-
"tensorflow-swift-apis": "7a3ed481bba53a7cd82f8a46c0df9f09a6e9747f",
404+
"tensorflow-swift-apis": "6045403998c155382466ece924a60b4cfe9cd466",
405405
"indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a",
406406
"sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a"
407407
}

0 commit comments

Comments
 (0)