Skip to content

[AutoDiff] Remove 'withoutDerivative(at:in:)' and 'withDerivative(_:)'. #36124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmark/single-source/Differentiation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public func run_DifferentiationIdentity(N: Int) {
x
}
for _ in 0..<1000*N {
blackHole(valueWithGradient(at: 1, in: f))
blackHole(valueWithGradient(at: 1, of: f))
}
}

Expand All @@ -50,7 +50,7 @@ public func run_DifferentiationSquare(N: Int) {
x * x
}
for _ in 0..<1000*N {
blackHole(valueWithGradient(at: 1, in: f))
blackHole(valueWithGradient(at: 1, of: f))
}
}

Expand All @@ -66,7 +66,7 @@ public func run_DifferentiationArraySum(N: Int) {
return result
}
for _ in 0..<N {
blackHole(valueWithGradient(at: onesArray, in: sum))
blackHole(valueWithGradient(at: onesArray, of: sum))
}
}

Expand Down
126 changes: 63 additions & 63 deletions docs/DifferentiableProgramming.md

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2869,16 +2869,16 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
"synthesis of the 'Differentiable.move(by:)' requirement for %1 "
"requires 'wrappedValue' in property wrapper %0 to be mutable or have a "
"non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute"
"non-mutating 'move(by:)'; add an explicit '@noDerivative' attribute"
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
(/*wrapperType*/ Identifier, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
"synthesis of the 'Differentiable.move(by:)' requirement for %0 "
"requires all stored properties not marked with `@noDerivative` to be "
"mutable or have a non-mutating 'move(along:)'; use 'var' instead, or "
"mutable or have a non-mutating 'move(by:)'; use 'var' instead, or "
"add an explicit '@noDerivative' attribute "
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ IDENTIFIER(AtomicStoreOrdering)
IDENTIFIER(AtomicUpdateOrdering)

// Differentiable programming
IDENTIFIER(along)
IDENTIFIER(by)
IDENTIFIER(differential)
IDENTIFIER(direction)
IDENTIFIER(offset)
IDENTIFIER(move)
IDENTIFIER(pullback)
IDENTIFIER(TangentVector)
Expand Down
44 changes: 22 additions & 22 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@

using namespace swift;

/// Return true if `move(along:)` can be invoked on the given `Differentiable`-
/// Return true if `move(by:)` can be invoked on the given `Differentiable`-
/// conforming property.
///
/// If the given property is a `var`, return true because `move(along:)` can be
/// If the given property is a `var`, return true because `move(by:)` can be
/// invoked regardless. Otherwise, return true if and only if the property's
/// type's 'Differentiable.move(along:)' witness is non-mutating.
/// type's 'Differentiable.move(by:)' witness is non-mutating.
static bool canInvokeMoveAlongOnProperty(
VarDecl *vd, ProtocolConformanceRef diffableConformance) {
assert(diffableConformance && "Property must conform to 'Differentiable'");
// `var` always supports `move(along:)` since it is mutable.
// `var` always supports `move(by:)` since it is mutable.
if (vd->getIntroducer() == VarDecl::Introducer::Var)
return true;
// When the property is a `let`, the only case that would be supported is when
// it has a `move(along:)` protocol requirement witness that is non-mutating.
// it has a `move(by:)` protocol requirement witness that is non-mutating.
auto interfaceType = vd->getInterfaceType();
auto &C = vd->getASTContext();
auto witness = diffableConformance.getWitnessByName(
interfaceType, DeclName(C, C.Id_move, {C.Id_along}));
interfaceType, DeclName(C, C.Id_move, {C.Id_by}));
if (!witness)
return false;
auto *decl = cast<FuncDecl>(witness.getDecl());
Expand All @@ -70,7 +70,7 @@ getStoredPropertiesForDifferentiation(
for (auto *vd : nominal->getStoredProperties()) {
// Peer through property wrappers: use original wrapped properties instead.
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
// Skip immutable wrapped properties. `mutating func move(along:)` cannot
// Skip immutable wrapped properties. `mutating func move(by:)` cannot
// be synthesized to update these properties.
if (!originalProperty->isSettable(DC))
continue;
Expand All @@ -87,8 +87,8 @@ getStoredPropertiesForDifferentiation(
varType, diffableProto, nominal);
if (!conformance)
continue;
// Skip `let` stored properties with a mutating `move(along:)` if requested.
// `mutating func move(along:)` cannot be synthesized to update `let`
// Skip `let` stored properties with a mutating `move(by:)` if requested.
// `mutating func move(by:)` cannot be synthesized to update `let`
// properties.
if (!includeLetPropertiesWithNonmutatingMoveAlong &&
!canInvokeMoveAlongOnProperty(vd, conformance))
Expand Down Expand Up @@ -214,14 +214,14 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
});
}

/// Synthesize body for `move(along:)`.
/// Synthesize body for `move(by:)`.
static std::pair<BraceStmt *, bool>
deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
auto &C = funcDecl->getASTContext();
auto *parentDC = funcDecl->getParent();
auto *nominal = parentDC->getSelfNominalTypeDecl();

// Get `Differentiable.move(along:)` protocol requirement.
// Get `Differentiable.move(by:)` protocol requirement.
auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable);
auto *requirement = getProtocolRequirement(diffProto, C.Id_move);

Expand All @@ -236,31 +236,31 @@ deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);

// Create call expression applying a member `move(along:)` method to a
// parameter member: `self.<member>.move(along: direction.<member>)`.
// Create call expression applying a member `move(by:)` method to a
// parameter member: `self.<member>.move(by: offset.<member>)`.
auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
auto *module = nominal->getModuleContext();
auto memberType =
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
auto confRef = module->lookupConformance(memberType, diffProto);
assert(confRef && "Member does not conform to `Differentiable`");

// Get member type's requirement witness: `<Member>.move(along:)`.
// Get member type's requirement witness: `<Member>.move(by:)`.
ValueDecl *memberWitnessDecl = requirement;
if (confRef.isConcrete())
if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement))
memberWitnessDecl = witness;
assert(memberWitnessDecl && "Member witness declaration must exist");

// Create reference to member method: `self.<member>.move(along:)`.
// Create reference to member method: `self.<member>.move(by:)`.
Expr *memberExpr =
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
/*Implicit*/ true);
auto *memberMethodExpr =
new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl,
DeclNameLoc(), /*Implicit*/ true);

// Create reference to parameter member: `direction.<member>`.
// Create reference to parameter member: `offset.<member>`.
VarDecl *paramMember = nullptr;
auto *paramNominal = paramDecl->getType()->getAnyNominal();
assert(paramNominal && "Parameter should have a nominal type");
Expand All @@ -275,12 +275,12 @@ deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
auto *paramMemberExpr =
new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(),
/*Implicit*/ true);
// Create expression: `self.<member>.move(along: direction.<member>)`.
// Create expression: `self.<member>.move(by: offset.<member>)`.
return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr},
{C.Id_along});
{C.Id_by});
};

// Collect member `move(along:)` method call expressions.
// Collect member `move(by:)` method call expressions.
SmallVector<ASTNode, 2> memberMethodCallExprs;
SmallVector<Identifier, 2> memberNames;
for (auto *member : diffProperties) {
Expand Down Expand Up @@ -326,14 +326,14 @@ static ValueDecl *deriveDifferentiable_method(
return funcDecl;
}

/// Synthesize the `move(along:)` function declaration.
/// Synthesize the `move(by:)` function declaration.
static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
auto &C = derived.Context;
auto *parentDC = derived.getConformanceContext();
auto tangentType =
getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC);
return deriveDifferentiable_method(
derived, C.Id_move, C.Id_along, C.Id_direction, tangentType,
derived, C.Id_move, C.Id_by, C.Id_offset, tangentType,
C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr});
}

Expand Down Expand Up @@ -561,7 +561,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Diagnose wrapped properties whose property wrappers do not define
// `wrappedValue.set`. `mutating func move(along:)` cannot be synthesized
// `wrappedValue.set`. `mutating func move(by:)` cannot be synthesized
// to update these properties.
if (!originalProperty->isSettable(DC)) {
auto *wrapperDecl =
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
return getRequirement(KnownProtocolKind::AdditiveArithmetic);
}

// Differentiable.move(along:)
// Differentiable.move(by:)
if (name.isCompoundName() && name.getBaseName() == ctx.Id_move) {
auto argumentNames = name.getArgumentNames();
if (argumentNames.size() == 1 && argumentNames[0] == ctx.Id_along)
if (argumentNames.size() == 1 && argumentNames[0] == ctx.Id_by)
return getRequirement(KnownProtocolKind::Differentiable);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,38 @@ public func withLeakChecking(
file: file, line: line)
}

@inlinable
@_semantics("autodiff.nonvarying")
public func withoutDerivative<T, R>(at x: T, in body: (T) -> R) -> R
{
body(x)
}

public extension Differentiable {
/// Applies the given closure to the derivative of `self`.
///
/// Returns `self` like an identity function. When the return value is used in
/// a context where it is differentiated with respect to, applies the given
/// closure to the derivative of the return value.
@inlinable
@differentiable(reverse, wrt: self)
func withDerivative(_ body: @escaping (inout TangentVector) -> Void) -> Self {
return self
}

@inlinable
@derivative(of: withDerivative)
internal func _vjpWithDerivative(
_ body: @escaping (inout TangentVector) -> Void
) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
return (self, { grad in
var grad = grad
body(&grad)
return grad
})
}
}

public extension TestSuite {
/// Execute test function and check expected leak count.
func testWithLeakChecking(
Expand Down Expand Up @@ -325,41 +357,41 @@ where T: Differentiable & FloatingPoint, T == T.TangentVector {
// Differential operators for `${Self}<T>`.

public func gradient<T, R: FloatingPoint>(
at x: T, in f: @differentiable(reverse) (T) -> ${Self}<R>
at x: T, of f: @differentiable(reverse) (T) -> ${Self}<R>
) -> T.TangentVector where R.TangentVector == R {
return pullback(at: x, in: f)(1)
return pullback(at: x, of: f)(1)
}

public func gradient<T, U, R: FloatingPoint>(
at x: T, _ y: U, in f: @differentiable(reverse) (T, U) -> ${Self}<R>
at x: T, _ y: U, of f: @differentiable(reverse) (T, U) -> ${Self}<R>
) -> (T.TangentVector, U.TangentVector) where R.TangentVector == R {
return pullback(at: x, y, in: f)(1)
return pullback(at: x, y, of: f)(1)
}

public func derivative<T: FloatingPoint, R>(
at x: ${Self}<T>, in f: @differentiable(reverse) (${Self}<T>) -> R
at x: ${Self}<T>, of f: @differentiable(reverse) (${Self}<T>) -> R
) -> R.TangentVector where T.TangentVector == T {
return differential(at: x, in: f)(1)
return differential(at: x, of: f)(1)
}

public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
at x: ${Self}<T>, _ y: ${Self}<U>,
in f: @differentiable(reverse) (${Self}<T>, ${Self}<U>) -> R
of f: @differentiable(reverse) (${Self}<T>, ${Self}<U>) -> R
) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U {
return differential(at: x, y, in: f)(1, 1)
return differential(at: x, y, of: f)(1, 1)
}

public func valueWithGradient<T, R: FloatingPoint>(
at x: T, in f: @differentiable(reverse) (T) -> ${Self}<R>
at x: T, of f: @differentiable(reverse) (T) -> ${Self}<R>
) -> (value: ${Self}<R>, gradient: T.TangentVector) {
let (y, pullback) = valueWithPullback(at: x, in: f)
let (y, pullback) = valueWithPullback(at: x, of: f)
return (y, pullback(1))
}

public func valueWithDerivative<T: FloatingPoint, R>(
at x: ${Self}<T>, in f: @differentiable(reverse) (${Self}<T>) -> R
at x: ${Self}<T>, of f: @differentiable(reverse) (${Self}<T>) -> R
) -> (value: R, derivative: R.TangentVector) {
let (y, differential) = valueWithDifferential(at: x, in: f)
let (y, differential) = valueWithDifferential(at: x, of: f)
return (y, differential(1))
}

Expand Down
Loading