Skip to content

[AutoDiff] Added support for 'withoutDerivative(at:)'. #25691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -413,16 +413,16 @@ NOTE(autodiff_external_nondifferentiable_function,none,
"'@differentiable' and that are defined in other files", ())
NOTE(autodiff_nondifferentiable_argument,none,
"cannot differentiate through a non-differentiable argument; do you want "
"to add '.withoutDerivative()'?", ())
"to use 'withoutDerivative(at:)'?", ())
NOTE(autodiff_nondifferentiable_result,none,
"cannot differentiate through a non-differentiable result; do you want to "
"add '.withoutDerivative()'?", ())
"use 'withoutDerivative(at:)'?", ())
NOTE(autodiff_noderivative_stored_property,none,
"cannot differentiate through a '@noDerivative' stored property; do you "
"want to add '.withoutDerivative()'?", ())
"want to use 'withoutDerivative(at:)'?", ())
WARNING(autodiff_nonvaried_result_fixit,none,
"result does not depend on differentiation arguments and will always "
"have a zero derivative; do you want to add '.withoutDerivative()'?",
"have a zero derivative; do you want to use 'withoutDerivative(at:)'?",
())
NOTE(autodiff_enums_unsupported,none,
"differentiating enum values is not yet supported", ())
Expand Down
10 changes: 6 additions & 4 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4270,11 +4270,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
// have a zero derivative.
if (!getActivityInfo().isVaried(origResult, getIndices().parameters)) {
// Emit fixit if original result has a valid source location.
auto sourceLoc = origResult.getLoc().getEndSourceLoc();
if (sourceLoc.isValid()) {
auto startLoc = origResult.getLoc().getStartSourceLoc();
auto endLoc = origResult.getLoc().getEndSourceLoc();
if (startLoc.isValid() && endLoc.isValid()) {
getContext()
.diagnose(sourceLoc, diag::autodiff_nonvaried_result_fixit)
.fixItInsertAfter(sourceLoc, ".withoutDerivative()");
.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
.fixItInsert(startLoc, "withoutDerivative(at:")
.fixItInsertAfter(endLoc, ")");
}
}
builder.setInsertionPoint(
Expand Down
14 changes: 8 additions & 6 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,14 @@ public extension Differentiable where TangentVector == Self {
}
}

public extension Differentiable {
/// Identity function that stops derivatives from propagating.
@inlinable
@inline(__always)
@_semantics("autodiff.nonvarying")
func withoutDerivative() -> Self { return self }
/// Returns `x` like an identity function. When used in a context where `x` is
/// being differentiated with respect to, this function will not produce any
/// derivative at `x`.
@inlinable
@inline(__always)
@_semantics("autodiff.nonvarying")
public func withoutDerivative<T>(at x: T) -> T {
x
}

/// Applies the given closure `body` to `x`. When used in a context where `x` is
Expand Down
12 changes: 6 additions & 6 deletions test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ struct NoDerivativeProperty : Differentiable {
// expected-error @+1 {{function is not differentiable}}
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in
var tmp = s
// expected-note @+1 {{cannot differentiate through a '@noDerivative' stored property; do you want to add '.withoutDerivative()'?}}
// expected-note @+1 {{cannot differentiate through a '@noDerivative' stored property; do you want to use 'withoutDerivative(at:)'?}}
tmp.y = tmp.x
return tmp.x
}
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{13-13=.withoutDerivative()}}
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{13-13=)}}
return s.y
}
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{7-7=.withoutDerivative()}}
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{3-3=withoutDerivative(at:}} {{7-7=)}}
$0.y
}

Expand All @@ -74,7 +74,7 @@ _ = gradient(at: 0, in: uses_optionals)

func base(_ x: Float) -> Float {
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to add '.withoutDerivative()'?}}
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
return Float(Int(x))
}

Expand Down Expand Up @@ -225,7 +225,7 @@ let no_return: @differentiable (Float) -> Float = { x in
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func roundingGivesError(x: Float) -> Float {
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to add '.withoutDerivative()'?}}
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
return Float(Int(x))
}

Expand Down Expand Up @@ -261,7 +261,7 @@ func one() -> Float {
}
@differentiable
func nonVariedResult(_ x: Float) -> Float {
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{15-15=.withoutDerivative()}}
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{15-15=)}}
return one()
}

Expand Down
2 changes: 1 addition & 1 deletion utils/update_checkout/update-checkout-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@
"clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a",
"libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a",
"tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86",
"tensorflow-swift-apis": "7a3ed481bba53a7cd82f8a46c0df9f09a6e9747f",
"tensorflow-swift-apis": "6045403998c155382466ece924a60b4cfe9cd466",
"indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a",
"sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-06-17-a"
}
Expand Down