Skip to content

[AutoDiff] Rename 'in:' to 'of:' in differential operators. #36121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmark/single-source/Differentiation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public func run_DifferentiationIdentity(N: Int) {
x
}
for _ in 0..<1000*N {
blackHole(valueWithGradient(at: 1, in: f))
blackHole(valueWithGradient(at: 1, of: f))
}
}

Expand All @@ -50,7 +50,7 @@ public func run_DifferentiationSquare(N: Int) {
x * x
}
for _ in 0..<1000*N {
blackHole(valueWithGradient(at: 1, in: f))
blackHole(valueWithGradient(at: 1, of: f))
}
}

Expand All @@ -66,7 +66,7 @@ public func run_DifferentiationArraySum(N: Int) {
return result
}
for _ in 0..<N {
blackHole(valueWithGradient(at: onesArray, in: sum))
blackHole(valueWithGradient(at: onesArray, of: sum))
}
}

Expand Down
126 changes: 63 additions & 63 deletions docs/DifferentiableProgramming.md

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2869,16 +2869,16 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
"synthesis of the 'Differentiable.move(by:)' requirement for %1 "
"requires 'wrappedValue' in property wrapper %0 to be mutable or have a "
"non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute"
"non-mutating 'move(by:)'; add an explicit '@noDerivative' attribute"
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
(/*wrapperType*/ Identifier, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
"synthesis of the 'Differentiable.move(by:)' requirement for %0 "
"requires all stored properties not marked with `@noDerivative` to be "
"mutable or have a non-mutating 'move(along:)'; use 'var' instead, or "
"mutable or have a non-mutating 'move(by:)'; use 'var' instead, or "
"add an explicit '@noDerivative' attribute "
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ IDENTIFIER(AtomicStoreOrdering)
IDENTIFIER(AtomicUpdateOrdering)

// Differentiable programming
IDENTIFIER(along)
IDENTIFIER(by)
IDENTIFIER(differential)
IDENTIFIER(direction)
IDENTIFIER(offset)
IDENTIFIER(move)
IDENTIFIER(pullback)
IDENTIFIER(TangentVector)
Expand Down
44 changes: 22 additions & 22 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@

using namespace swift;

/// Return true if `move(along:)` can be invoked on the given `Differentiable`-
/// Return true if `move(by:)` can be invoked on the given `Differentiable`-
/// conforming property.
///
/// If the given property is a `var`, return true because `move(along:)` can be
/// If the given property is a `var`, return true because `move(by:)` can be
/// invoked regardless. Otherwise, return true if and only if the property's
/// type's 'Differentiable.move(along:)' witness is non-mutating.
/// type's 'Differentiable.move(by:)' witness is non-mutating.
static bool canInvokeMoveAlongOnProperty(
VarDecl *vd, ProtocolConformanceRef diffableConformance) {
assert(diffableConformance && "Property must conform to 'Differentiable'");
// `var` always supports `move(along:)` since it is mutable.
// `var` always supports `move(by:)` since it is mutable.
if (vd->getIntroducer() == VarDecl::Introducer::Var)
return true;
// When the property is a `let`, the only case that would be supported is when
// it has a `move(along:)` protocol requirement witness that is non-mutating.
// it has a `move(by:)` protocol requirement witness that is non-mutating.
auto interfaceType = vd->getInterfaceType();
auto &C = vd->getASTContext();
auto witness = diffableConformance.getWitnessByName(
interfaceType, DeclName(C, C.Id_move, {C.Id_along}));
interfaceType, DeclName(C, C.Id_move, {C.Id_by}));
if (!witness)
return false;
auto *decl = cast<FuncDecl>(witness.getDecl());
Expand All @@ -70,7 +70,7 @@ getStoredPropertiesForDifferentiation(
for (auto *vd : nominal->getStoredProperties()) {
// Peer through property wrappers: use original wrapped properties instead.
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
// Skip immutable wrapped properties. `mutating func move(along:)` cannot
// Skip immutable wrapped properties. `mutating func move(by:)` cannot
// be synthesized to update these properties.
if (!originalProperty->isSettable(DC))
continue;
Expand All @@ -87,8 +87,8 @@ getStoredPropertiesForDifferentiation(
varType, diffableProto, nominal);
if (!conformance)
continue;
// Skip `let` stored properties with a mutating `move(along:)` if requested.
// `mutating func move(along:)` cannot be synthesized to update `let`
// Skip `let` stored properties with a mutating `move(by:)` if requested.
// `mutating func move(by:)` cannot be synthesized to update `let`
// properties.
if (!includeLetPropertiesWithNonmutatingMoveAlong &&
!canInvokeMoveAlongOnProperty(vd, conformance))
Expand Down Expand Up @@ -214,14 +214,14 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
});
}

/// Synthesize body for `move(along:)`.
/// Synthesize body for `move(by:)`.
static std::pair<BraceStmt *, bool>
deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
auto &C = funcDecl->getASTContext();
auto *parentDC = funcDecl->getParent();
auto *nominal = parentDC->getSelfNominalTypeDecl();

// Get `Differentiable.move(along:)` protocol requirement.
// Get `Differentiable.move(by:)` protocol requirement.
auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable);
auto *requirement = getProtocolRequirement(diffProto, C.Id_move);

Expand All @@ -236,31 +236,31 @@ deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);

// Create call expression applying a member `move(along:)` method to a
// parameter member: `self.<member>.move(along: direction.<member>)`.
// Create call expression applying a member `move(by:)` method to a
// parameter member: `self.<member>.move(by: offset.<member>)`.
auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
auto *module = nominal->getModuleContext();
auto memberType =
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
auto confRef = module->lookupConformance(memberType, diffProto);
assert(confRef && "Member does not conform to `Differentiable`");

// Get member type's requirement witness: `<Member>.move(along:)`.
// Get member type's requirement witness: `<Member>.move(by:)`.
ValueDecl *memberWitnessDecl = requirement;
if (confRef.isConcrete())
if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement))
memberWitnessDecl = witness;
assert(memberWitnessDecl && "Member witness declaration must exist");

// Create reference to member method: `self.<member>.move(along:)`.
// Create reference to member method: `self.<member>.move(by:)`.
Expr *memberExpr =
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
/*Implicit*/ true);
auto *memberMethodExpr =
new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl,
DeclNameLoc(), /*Implicit*/ true);

// Create reference to parameter member: `direction.<member>`.
// Create reference to parameter member: `offset.<member>`.
VarDecl *paramMember = nullptr;
auto *paramNominal = paramDecl->getType()->getAnyNominal();
assert(paramNominal && "Parameter should have a nominal type");
Expand All @@ -275,12 +275,12 @@ deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
auto *paramMemberExpr =
new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(),
/*Implicit*/ true);
// Create expression: `self.<member>.move(along: direction.<member>)`.
// Create expression: `self.<member>.move(by: offset.<member>)`.
return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr},
{C.Id_along});
{C.Id_by});
};

// Collect member `move(along:)` method call expressions.
// Collect member `move(by:)` method call expressions.
SmallVector<ASTNode, 2> memberMethodCallExprs;
SmallVector<Identifier, 2> memberNames;
for (auto *member : diffProperties) {
Expand Down Expand Up @@ -326,14 +326,14 @@ static ValueDecl *deriveDifferentiable_method(
return funcDecl;
}

/// Synthesize the `move(along:)` function declaration.
/// Synthesize the `move(by:)` function declaration.
static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
auto &C = derived.Context;
auto *parentDC = derived.getConformanceContext();
auto tangentType =
getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC);
return deriveDifferentiable_method(
derived, C.Id_move, C.Id_along, C.Id_direction, tangentType,
derived, C.Id_move, C.Id_by, C.Id_offset, tangentType,
C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr});
}

Expand Down Expand Up @@ -561,7 +561,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Diagnose wrapped properties whose property wrappers do not define
// `wrappedValue.set`. `mutating func move(along:)` cannot be synthesized
// `wrappedValue.set`. `mutating func move(by:)` cannot be synthesized
// to update these properties.
if (!originalProperty->isSettable(DC)) {
auto *wrapperDecl =
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
return getRequirement(KnownProtocolKind::AdditiveArithmetic);
}

// Differentiable.move(along:)
// Differentiable.move(by:)
if (name.isCompoundName() && name.getBaseName() == ctx.Id_move) {
auto argumentNames = name.getArgumentNames();
if (argumentNames.size() == 1 && argumentNames[0] == ctx.Id_along)
if (argumentNames.size() == 1 && argumentNames[0] == ctx.Id_by)
return getRequirement(KnownProtocolKind::Differentiable);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,41 +325,41 @@ where T: Differentiable & FloatingPoint, T == T.TangentVector {
// Differential operators for `${Self}<T>`.

public func gradient<T, R: FloatingPoint>(
at x: T, in f: @differentiable(reverse) (T) -> ${Self}<R>
at x: T, of f: @differentiable(reverse) (T) -> ${Self}<R>
) -> T.TangentVector where R.TangentVector == R {
return pullback(at: x, in: f)(1)
return pullback(at: x, of: f)(1)
}

public func gradient<T, U, R: FloatingPoint>(
at x: T, _ y: U, in f: @differentiable(reverse) (T, U) -> ${Self}<R>
at x: T, _ y: U, of f: @differentiable(reverse) (T, U) -> ${Self}<R>
) -> (T.TangentVector, U.TangentVector) where R.TangentVector == R {
return pullback(at: x, y, in: f)(1)
return pullback(at: x, y, of: f)(1)
}

public func derivative<T: FloatingPoint, R>(
at x: ${Self}<T>, in f: @differentiable(reverse) (${Self}<T>) -> R
at x: ${Self}<T>, of f: @differentiable(reverse) (${Self}<T>) -> R
) -> R.TangentVector where T.TangentVector == T {
return differential(at: x, in: f)(1)
return differential(at: x, of: f)(1)
}

public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
at x: ${Self}<T>, _ y: ${Self}<U>,
in f: @differentiable(reverse) (${Self}<T>, ${Self}<U>) -> R
of f: @differentiable(reverse) (${Self}<T>, ${Self}<U>) -> R
) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U {
return differential(at: x, y, in: f)(1, 1)
return differential(at: x, y, of: f)(1, 1)
}

public func valueWithGradient<T, R: FloatingPoint>(
at x: T, in f: @differentiable(reverse) (T) -> ${Self}<R>
at x: T, of f: @differentiable(reverse) (T) -> ${Self}<R>
) -> (value: ${Self}<R>, gradient: T.TangentVector) {
let (y, pullback) = valueWithPullback(at: x, in: f)
let (y, pullback) = valueWithPullback(at: x, of: f)
return (y, pullback(1))
}

public func valueWithDerivative<T: FloatingPoint, R>(
at x: ${Self}<T>, in f: @differentiable(reverse) (${Self}<T>) -> R
at x: ${Self}<T>, of f: @differentiable(reverse) (${Self}<T>) -> R
) -> (value: R, derivative: R.TangentVector) {
let (y, differential) = valueWithDifferential(at: x, in: f)
let (y, differential) = valueWithDifferential(at: x, of: f)
return (y, differential(1))
}

Expand Down
40 changes: 17 additions & 23 deletions stdlib/public/Differentiation/AnyDifferentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Swift

internal protocol _AnyDifferentiableBox {
// `Differentiable` requirements.
mutating func _move(along direction: AnyDerivative)
mutating func _move(by offset: AnyDerivative)

/// The underlying base value, type-erased to `Any`.
var _typeErasedBase: Any { get }
Expand All @@ -50,14 +50,11 @@ internal struct _ConcreteDifferentiableBox<T: Differentiable>: _AnyDifferentiabl
return (self as? _ConcreteDifferentiableBox<U>)?._base
}

mutating func _move(along direction: AnyDerivative) {
guard
let directionBase =
direction.base as? T.TangentVector
else {
_derivativeTypeMismatch(T.self, type(of: direction.base))
mutating func _move(by offset: AnyDerivative) {
guard let offsetBase = offset.base as? T.TangentVector else {
_derivativeTypeMismatch(T.self, type(of: offset.base))
}
_base.move(along: directionBase)
_base.move(by: offsetBase)
}
}

Expand Down Expand Up @@ -100,8 +97,8 @@ public struct AnyDifferentiable: Differentiable {

public typealias TangentVector = AnyDerivative

public mutating func move(along direction: TangentVector) {
_box._move(along: direction)
public mutating func move(by offset: TangentVector) {
_box._move(by: offset)
}
}

Expand All @@ -121,7 +118,7 @@ internal protocol _AnyDerivativeBox {
func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox

// `Differentiable` requirements.
mutating func _move(along direction: _AnyDerivativeBox)
mutating func _move(by offset: _AnyDerivativeBox)

/// The underlying base value, type-erased to `Any`.
var _typeErasedBase: Any { get }
Expand Down Expand Up @@ -215,19 +212,16 @@ where T: Differentiable, T.TangentVector == T {

// `Differentiable` requirements.
@inlinable
mutating func _move(along direction: _AnyDerivativeBox) {
if direction._isOpaqueZero() {
mutating func _move(by offset: _AnyDerivativeBox) {
if offset._isOpaqueZero() {
return
}
// The case where `self._isOpaqueZero()` returns true is handled in
// `AnyDerivative.move(along:)`.
guard
let directionBase =
direction._unboxed(to: T.TangentVector.self)
else {
_derivativeTypeMismatch(T.self, type(of: direction._typeErasedBase))
// `AnyDerivative.move(by:)`.
guard let offsetBase = offset._unboxed(to: T.TangentVector.self) else {
_derivativeTypeMismatch(T.self, type(of: offset._typeErasedBase))
}
_base.move(along: directionBase)
_base.move(by: offsetBase)
}
}

Expand Down Expand Up @@ -362,12 +356,12 @@ public struct AnyDerivative: Differentiable & AdditiveArithmetic {

// `Differentiable` requirements.
@inlinable
public mutating func move(along direction: TangentVector) {
public mutating func move(by offset: TangentVector) {
if _box._isOpaqueZero() {
_box = direction._box
_box = offset._box
return
}
_box._move(along: direction._box)
_box._move(by: offset._box)
}
}

Expand Down
Loading