Skip to content

Commit 8bc6143

Browse files
committed
[AutoDiff] Rename 'move(along:)' to 'move(by:)'.
Rename `move(along:)` to `move(by:)` based on the proposal feedback. The main argument for the change is that tangent vectors specify both a direction and a magnitude, whereas `along:` does not indicate that `self` is being moved by the specified magnitude.
1 parent 6189e8e commit 8bc6143

29 files changed

+192
-198
lines changed

docs/DifferentiableProgramming.md

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -982,10 +982,10 @@ public protocol Differentiable {
982982
associatedtype TangentVector: Differentiable & AdditiveArithmetic
983983
where TangentVector == TangentVector.TangentVector
984984

985-
/// Moves `self` along the given direction. In Riemannian geometry, this is
985+
/// Moves `self` by the given offset. In Riemannian geometry, this is
986986
/// equivalent to exponential map, which moves `self` on the geodesic
987-
/// surface along the given tangent vector.
988-
mutating func move(along direction: TangentVector)
987+
/// surface by the given tangent vector.
988+
mutating func move(by offset: TangentVector)
989989
}
990990
```
991991

@@ -1037,19 +1037,19 @@ and
10371037
[`+(_:_:)`](https://developer.apple.com/documentation/swift/additivearithmetic/3126821)
10381038
are necessary for initializing and accumulating derivative values.
10391039

1040-
The `move(along:)` method is equivalent to the mathematical notion of
1040+
The `move(by:)` method is equivalent to the mathematical notion of
10411041
[exponential map](https://en.wikipedia.org/wiki/Exponential_map_\(Riemannian_geometry\)),
10421042
which takes a tangent vector (e.g. a derivative), and moves the value along the
10431043
direction specified by the tangent vector on the geodesic surface of the
10441044
manifold. In vector spaces where the tangent vector is of the same vector space
1045-
as the original differentiable space, `move(along:)` is equivalent to vector
1045+
as the original differentiable space, `move(by:)` is equivalent to vector
10461046
addition. Mathematical optimization algorithms such as gradient descent will
10471047
make use of this method.
10481048

10491049
```swift
10501050
public extension Differentiable where Self == TangentVector {
1051-
mutating func move(along direction: TangentVector) {
1052-
self += direction
1051+
mutating func move(by offset: TangentVector) {
1052+
self += offset
10531053
}
10541054
}
10551055
```
@@ -1096,9 +1096,9 @@ extension Array: Differentiable where Element: Differentiable {
10961096
...
10971097
}
10981098

1099-
public mutating func move(along direction: TangentVector) {
1099+
public mutating func move(by offset: TangentVector) {
11001100
for i in indices {
1101-
self[i].move(along: Element.TangentVector(direction.elements[i]))
1101+
self[i].move(by: Element.TangentVector(offset.elements[i]))
11021102
}
11031103
}
11041104
}
@@ -1116,9 +1116,9 @@ extension Dictionary: Differentiable where Value: Differentiable {
11161116
...
11171117
}
11181118

1119-
public mutating func move(along direction: TangentVector) {
1119+
public mutating func move(by offset: TangentVector) {
11201120
for i in indices {
1121-
self[i].move(along: Value.TangentVector(direction.elements[i]))
1121+
self[i].move(by: Value.TangentVector(offset.elements[i]))
11221122
}
11231123
}
11241124
}
@@ -1134,9 +1134,9 @@ extension Optional: Differentiable where Wrapped: Differentiable {
11341134
...
11351135
}
11361136

1137-
public mutating func move(along direction: TangentVector) {
1138-
if let value = direction.value {
1139-
self?.move(along: value)
1137+
public mutating func move(by offset: TangentVector) {
1138+
if let value = offset.value {
1139+
self?.move(by: value)
11401140
}
11411141
}
11421142
}
@@ -1189,12 +1189,12 @@ product manifold of the manifolds each differentiable variable's type
11891189
represents. Differentiable variables' types are required to conform to
11901190
`Differentiable` because the synthesized implementation needs to access each
11911191
differentiable variable's type's `TangentVector` associated type and invoke each
1192-
differentiable variable's implementation of `move(along:)`. Because the
1193-
synthesized implementation needs to invoke `move(along:)` on each differentiable
1194-
variable, the differentiable variables must have a `move(along:)` which satisfies the
1192+
differentiable variable's implementation of `move(by:)`. Because the
1193+
synthesized implementation needs to invoke `move(by:)` on each differentiable
1194+
variable, the differentiable variables must have a `move(by:)` which satisfies the
11951195
protocol requirement and can be invoked on the property. That is, the property
11961196
must be either a variable (`var`) or a constant (`let`) with a non-`mutating`
1197-
implementation of the `move(along:)` protocol requirement.
1197+
implementation of the `move(by:)` protocol requirement.
11981198

11991199
The synthesized `TangentVector` has the same effective access level as the
12001200
original type declaration. Properties in the synthesized `TangentVector` have
@@ -1206,7 +1206,7 @@ example, synthesized `TangentVector`s always adopt the `AdditiveArithmetic` and
12061206
`Differentiable` protocols because the `Differentiable` protocol requires that
12071207
`TangentVector` conforms to `AdditiveArithmetic` and `Differentiable`.
12081208

1209-
The synthesized `move(along:)` method calls `move(along:)` for each pair of a
1209+
The synthesized `move(by:)` method calls `move(by:)` for each pair of a
12101210
differentiable variable and its corresponding property in `TangentVector`.
12111211

12121212
```swift
@@ -1223,9 +1223,9 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
12231223
// var y: U.TangentVector
12241224
// }
12251225
//
1226-
// mutating func move(along direction: TangentVector) {
1227-
// x.move(along: direction.x)
1228-
// y.move(along: direction.y)
1226+
// mutating func move(by offset: TangentVector) {
1227+
// x.move(by: offset.x)
1228+
// y.move(by: offset.y)
12291229
// }
12301230
}
12311231
```
@@ -1235,7 +1235,7 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
12351235
The synthesized implementation of `Differentiable` protocol requirements already
12361236
excludes stored properties that are not differentiable variables, such as stored
12371237
properties that do not conform to `Differentiable` and `let`
1238-
properties that do not have a non-mutating `move(along:)`. In addition to this
1238+
properties that do not have a non-mutating `move(by:)`. In addition to this
12391239
behavior, we also introduce a `@noDerivative` declaration attribute, which can
12401240
be attached to properties that the programmer does not wish to include in the
12411241
synthesized `Differentiable` protocol requirement implementation.
@@ -1284,7 +1284,7 @@ test.swift:5:4: note: add a '@noDerivative' attribute to make it explicit
12841284
^
12851285
@noDerivative
12861286

1287-
test.swift:6:4: warning: synthesis of the 'Differentiable.move(along:)' requirement for 'Foo' requires all stored properties not marked with `@noDerivative` to be mutable
1287+
test.swift:6:4: warning: synthesis of the 'Differentiable.move(by:)' requirement for 'Foo' requires all stored properties not marked with `@noDerivative` to be mutable
12881288
let helperVariable: T
12891289

12901290
test.swift:6:4: note: change 'let' to 'var' to make it mutable
@@ -1307,7 +1307,7 @@ properties are declared to conform to `AdditiveArithmetic`. There are no
13071307
`@noDerivative` stored properties.
13081308

13091309
In these cases, the compiler will make `TangentVector` be a type alias for Self.
1310-
Method `move(along:)` will not be synthesized because a default implementation
1310+
Method `move(by:)` will not be synthesized because a default implementation
13111311
already exists.
13121312

13131313
```swift
@@ -2743,7 +2743,7 @@ for i in 0..<iterationCount {
27432743
}
27442744
𝛁loss.weight *= -learningRate
27452745
𝛁loss.bias *= -learningRate
2746-
model.move(along: 𝛁loss)
2746+
model.move(by: 𝛁loss)
27472747
if i.isMultiple(of: 10) {
27482748
print("Iteration: \(iteration) Avg Loss: \(loss / Float(data.count))")
27492749
}

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2869,16 +2869,16 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
28692869
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
28702870
/*nominalCanDeriveAdditiveArithmetic*/ bool))
28712871
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
2872-
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
2872+
"synthesis of the 'Differentiable.move(by:)' requirement for %1 "
28732873
"requires 'wrappedValue' in property wrapper %0 to be mutable or have a "
2874-
"non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute"
2874+
"non-mutating 'move(by:)'; add an explicit '@noDerivative' attribute"
28752875
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
28762876
(/*wrapperType*/ Identifier, /*nominalName*/ Identifier,
28772877
/*nominalCanDeriveAdditiveArithmetic*/ bool))
28782878
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
2879-
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
2879+
"synthesis of the 'Differentiable.move(by:)' requirement for %0 "
28802880
"requires all stored properties not marked with `@noDerivative` to be "
2881-
"mutable or have a non-mutating 'move(along:)'; use 'var' instead, or "
2881+
"mutable or have a non-mutating 'move(by:)'; use 'var' instead, or "
28822882
"add an explicit '@noDerivative' attribute "
28832883
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
28842884
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))

include/swift/AST/KnownIdentifiers.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,9 @@ IDENTIFIER(AtomicStoreOrdering)
228228
IDENTIFIER(AtomicUpdateOrdering)
229229

230230
// Differentiable programming
231-
IDENTIFIER(along)
231+
IDENTIFIER(by)
232232
IDENTIFIER(differential)
233-
IDENTIFIER(direction)
233+
IDENTIFIER(offset)
234234
IDENTIFIER(move)
235235
IDENTIFIER(pullback)
236236
IDENTIFIER(TangentVector)

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,24 @@
3434

3535
using namespace swift;
3636

37-
/// Return true if `move(along:)` can be invoked on the given `Differentiable`-
37+
/// Return true if `move(by:)` can be invoked on the given `Differentiable`-
3838
/// conforming property.
3939
///
40-
/// If the given property is a `var`, return true because `move(along:)` can be
40+
/// If the given property is a `var`, return true because `move(by:)` can be
4141
/// invoked regardless. Otherwise, return true if and only if the property's
42-
/// type's 'Differentiable.move(along:)' witness is non-mutating.
42+
/// type's 'Differentiable.move(by:)' witness is non-mutating.
4343
static bool canInvokeMoveAlongOnProperty(
4444
VarDecl *vd, ProtocolConformanceRef diffableConformance) {
4545
assert(diffableConformance && "Property must conform to 'Differentiable'");
46-
// `var` always supports `move(along:)` since it is mutable.
46+
// `var` always supports `move(by:)` since it is mutable.
4747
if (vd->getIntroducer() == VarDecl::Introducer::Var)
4848
return true;
4949
// When the property is a `let`, the only case that would be supported is when
50-
// it has a `move(along:)` protocol requirement witness that is non-mutating.
50+
// it has a `move(by:)` protocol requirement witness that is non-mutating.
5151
auto interfaceType = vd->getInterfaceType();
5252
auto &C = vd->getASTContext();
5353
auto witness = diffableConformance.getWitnessByName(
54-
interfaceType, DeclName(C, C.Id_move, {C.Id_along}));
54+
interfaceType, DeclName(C, C.Id_move, {C.Id_by}));
5555
if (!witness)
5656
return false;
5757
auto *decl = cast<FuncDecl>(witness.getDecl());
@@ -70,7 +70,7 @@ getStoredPropertiesForDifferentiation(
7070
for (auto *vd : nominal->getStoredProperties()) {
7171
// Peer through property wrappers: use original wrapped properties instead.
7272
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
73-
// Skip immutable wrapped properties. `mutating func move(along:)` cannot
73+
// Skip immutable wrapped properties. `mutating func move(by:)` cannot
7474
// be synthesized to update these properties.
7575
if (!originalProperty->isSettable(DC))
7676
continue;
@@ -87,8 +87,8 @@ getStoredPropertiesForDifferentiation(
8787
varType, diffableProto, nominal);
8888
if (!conformance)
8989
continue;
90-
// Skip `let` stored properties with a mutating `move(along:)` if requested.
91-
// `mutating func move(along:)` cannot be synthesized to update `let`
90+
// Skip `let` stored properties with a mutating `move(by:)` if requested.
91+
// `mutating func move(by:)` cannot be synthesized to update `let`
9292
// properties.
9393
if (!includeLetPropertiesWithNonmutatingMoveAlong &&
9494
!canInvokeMoveAlongOnProperty(vd, conformance))
@@ -214,14 +214,14 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
214214
});
215215
}
216216

217-
/// Synthesize body for `move(along:)`.
217+
/// Synthesize body for `move(by:)`.
218218
static std::pair<BraceStmt *, bool>
219219
deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
220220
auto &C = funcDecl->getASTContext();
221221
auto *parentDC = funcDecl->getParent();
222222
auto *nominal = parentDC->getSelfNominalTypeDecl();
223223

224-
// Get `Differentiable.move(along:)` protocol requirement.
224+
// Get `Differentiable.move(by:)` protocol requirement.
225225
auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable);
226226
auto *requirement = getProtocolRequirement(diffProto, C.Id_move);
227227

@@ -236,31 +236,31 @@ deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
236236
SmallVector<VarDecl *, 8> diffProperties;
237237
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
238238

239-
// Create call expression applying a member `move(along:)` method to a
240-
// parameter member: `self.<member>.move(along: direction.<member>)`.
239+
// Create call expression applying a member `move(by:)` method to a
240+
// parameter member: `self.<member>.move(by: offset.<member>)`.
241241
auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
242242
auto *module = nominal->getModuleContext();
243243
auto memberType =
244244
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
245245
auto confRef = module->lookupConformance(memberType, diffProto);
246246
assert(confRef && "Member does not conform to `Differentiable`");
247247

248-
// Get member type's requirement witness: `<Member>.move(along:)`.
248+
// Get member type's requirement witness: `<Member>.move(by:)`.
249249
ValueDecl *memberWitnessDecl = requirement;
250250
if (confRef.isConcrete())
251251
if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement))
252252
memberWitnessDecl = witness;
253253
assert(memberWitnessDecl && "Member witness declaration must exist");
254254

255-
// Create reference to member method: `self.<member>.move(along:)`.
255+
// Create reference to member method: `self.<member>.move(by:)`.
256256
Expr *memberExpr =
257257
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
258258
/*Implicit*/ true);
259259
auto *memberMethodExpr =
260260
new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl,
261261
DeclNameLoc(), /*Implicit*/ true);
262262

263-
// Create reference to parameter member: `direction.<member>`.
263+
// Create reference to parameter member: `offset.<member>`.
264264
VarDecl *paramMember = nullptr;
265265
auto *paramNominal = paramDecl->getType()->getAnyNominal();
266266
assert(paramNominal && "Parameter should have a nominal type");
@@ -275,12 +275,12 @@ deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
275275
auto *paramMemberExpr =
276276
new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(),
277277
/*Implicit*/ true);
278-
// Create expression: `self.<member>.move(along: direction.<member>)`.
278+
// Create expression: `self.<member>.move(by: offset.<member>)`.
279279
return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr},
280-
{C.Id_along});
280+
{C.Id_by});
281281
};
282282

283-
// Collect member `move(along:)` method call expressions.
283+
// Collect member `move(by:)` method call expressions.
284284
SmallVector<ASTNode, 2> memberMethodCallExprs;
285285
SmallVector<Identifier, 2> memberNames;
286286
for (auto *member : diffProperties) {
@@ -326,14 +326,14 @@ static ValueDecl *deriveDifferentiable_method(
326326
return funcDecl;
327327
}
328328

329-
/// Synthesize the `move(along:)` function declaration.
329+
/// Synthesize the `move(by:)` function declaration.
330330
static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
331331
auto &C = derived.Context;
332332
auto *parentDC = derived.getConformanceContext();
333333
auto tangentType =
334334
getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC);
335335
return deriveDifferentiable_method(
336-
derived, C.Id_move, C.Id_along, C.Id_direction, tangentType,
336+
derived, C.Id_move, C.Id_by, C.Id_offset, tangentType,
337337
C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr});
338338
}
339339

@@ -561,7 +561,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
561561
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
562562
continue;
563563
// Diagnose wrapped properties whose property wrappers do not define
564-
// `wrappedValue.set`. `mutating func move(along:)` cannot be synthesized
564+
// `wrappedValue.set`. `mutating func move(by:)` cannot be synthesized
565565
// to update these properties.
566566
if (!originalProperty->isSettable(DC)) {
567567
auto *wrapperDecl =

lib/Sema/DerivedConformances.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,10 +329,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
329329
return getRequirement(KnownProtocolKind::AdditiveArithmetic);
330330
}
331331

332-
// Differentiable.move(along:)
332+
// Differentiable.move(by:)
333333
if (name.isCompoundName() && name.getBaseName() == ctx.Id_move) {
334334
auto argumentNames = name.getArgumentNames();
335-
if (argumentNames.size() == 1 && argumentNames[0] == ctx.Id_along)
335+
if (argumentNames.size() == 1 && argumentNames[0] == ctx.Id_by)
336336
return getRequirement(KnownProtocolKind::Differentiable);
337337
}
338338

0 commit comments

Comments
 (0)