Skip to content

Commit 98f3545

Browse files
authored
[NFC] [AutoDiff] Gardening. (#27651)
1 parent a5dc918 commit 98f3545

File tree

5 files changed

+15
-19
lines changed

5 files changed

+15
-19
lines changed

include/swift/AST/Attr.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,8 +1537,8 @@ class DifferentiableAttr final
15371537
ParsedAutoDiffParameter> {
15381538
friend TrailingObjects;
15391539

1540-
/// Whether this function is linear (optional).
1541-
bool linear;
1540+
/// Whether this function is linear.
1541+
bool Linear;
15421542
/// The number of parsed parameters specified in 'wrt:'.
15431543
unsigned NumParsedParameters = 0;
15441544
/// The JVP function.
@@ -1621,7 +1621,7 @@ class DifferentiableAttr final
16211621
return NumParsedParameters;
16221622
}
16231623

1624-
bool isLinear() const { return linear; }
1624+
bool isLinear() const { return Linear; }
16251625

16261626
TrailingWhereClause *getWhereClause() const { return WhereClause; }
16271627

@@ -1676,8 +1676,8 @@ class DifferentiatingAttr final
16761676
DeclNameWithLoc Original;
16771677
/// The original function, resolved by the type checker.
16781678
FuncDecl *OriginalFunction = nullptr;
1679-
/// Whether this function is linear (optional).
1680-
bool linear;
1679+
/// Whether this function is linear.
1680+
bool Linear;
16811681
/// The number of parsed parameters specified in 'wrt:'.
16821682
unsigned NumParsedParameters = 0;
16831683
/// The differentiation parameters' indices, resolved by the type checker.
@@ -1706,7 +1706,7 @@ class DifferentiatingAttr final
17061706

17071707
DeclNameWithLoc getOriginal() const { return Original; }
17081708

1709-
bool isLinear() const { return linear; }
1709+
bool isLinear() const { return Linear; }
17101710

17111711
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
17121712
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }

lib/AST/Attr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,7 +1447,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
14471447
Optional<DeclNameWithLoc> vjp,
14481448
TrailingWhereClause *clause)
14491449
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1450-
linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
1450+
Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
14511451
VJP(std::move(vjp)), WhereClause(clause) {
14521452
std::copy(params.begin(), params.end(),
14531453
getTrailingObjects<ParsedAutoDiffParameter>());
@@ -1461,7 +1461,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
14611461
Optional<DeclNameWithLoc> vjp,
14621462
GenericSignature *derivativeGenSig)
14631463
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1464-
linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
1464+
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
14651465
ParameterIndices(indices) {
14661466
setDerivativeGenericSignature(context, derivativeGenSig);
14671467
}
@@ -1530,7 +1530,7 @@ DifferentiatingAttr::DifferentiatingAttr(
15301530
DeclNameWithLoc original, bool linear,
15311531
ArrayRef<ParsedAutoDiffParameter> params)
15321532
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1533-
Original(std::move(original)), linear(linear),
1533+
Original(std::move(original)), Linear(linear),
15341534
NumParsedParameters(params.size()) {
15351535
std::copy(params.begin(), params.end(),
15361536
getTrailingObjects<ParsedAutoDiffParameter>());
@@ -1540,7 +1540,7 @@ DifferentiatingAttr::DifferentiatingAttr(
15401540
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
15411541
DeclNameWithLoc original, bool linear, IndexSubset *indices)
15421542
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1543-
Original(std::move(original)), linear(linear), ParameterIndices(indices) {
1543+
Original(std::move(original)), Linear(linear), ParameterIndices(indices) {
15441544
}
15451545

15461546
DifferentiatingAttr *

lib/SILGen/SILGen.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
770770
paramIndices, origFnType);
771771
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
772772
assert(silDiffAttr->getIndices() == indices &&
773-
"Expected matching @differentiable and [differentiable]");
773+
"Expected matching @differentiable and [differentiable] indices");
774774

775775
auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule());
776776
auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType(
@@ -875,10 +875,6 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) {
875875
if (!hasFunction(thunk))
876876
emitNativeToForeignThunk(thunk);
877877
}
878-
879-
// TODO: Handle SILGen for `@differentiating` attribute.
880-
// Tentative solution: SILGen derivative function normally but also emit
881-
// mangled redirection thunk for retroactive differentiation.
882878
}
883879

884880
void SILGenModule::emitFunction(FuncDecl *fd) {

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ struct DifferentiationInvoker {
353353
Kind kind;
354354
union Value {
355355
/// The instruction associated with the `DifferentiableFunctionInst` case.
356-
DifferentiableFunctionInst *adFuncInst;
357-
Value(DifferentiableFunctionInst *inst) : adFuncInst(inst) {}
356+
DifferentiableFunctionInst *diffFuncInst;
357+
Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {}
358358

359359
/// The parent `apply` instruction and `[differentiable]` attribute
360360
/// associated with the `IndirectDifferentiation` case.
@@ -385,7 +385,7 @@ struct DifferentiationInvoker {
385385

386386
DifferentiableFunctionInst *getDifferentiableFunctionInst() const {
387387
assert(kind == Kind::DifferentiableFunctionInst);
388-
return value.adFuncInst;
388+
return value.diffFuncInst;
389389
}
390390

391391
std::pair<ApplyInst *, SILDifferentiableAttr *>

stdlib/public/core/DifferentiationSupport.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -919,7 +919,7 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic {
919919
@differentiating(+)
920920
@usableFromInline internal static func _jvpAdd(
921921
lhs: AnyDerivative, rhs: AnyDerivative
922-
) -> (value: AnyDerivative,
922+
) -> (value: AnyDerivative,
923923
differential: (AnyDerivative, AnyDerivative) -> (AnyDerivative)) {
924924
return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs })
925925
}

0 commit comments

Comments
 (0)