Skip to content

[NFC] [AutoDiff] Gardening. #27651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1537,8 +1537,8 @@ class DifferentiableAttr final
ParsedAutoDiffParameter> {
friend TrailingObjects;

/// Whether this function is linear (optional).
bool linear;
/// Whether this function is linear.
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Expand Down Expand Up @@ -1621,7 +1621,7 @@ class DifferentiableAttr final
return NumParsedParameters;
}

bool isLinear() const { return linear; }
bool isLinear() const { return Linear; }

TrailingWhereClause *getWhereClause() const { return WhereClause; }

Expand Down Expand Up @@ -1676,8 +1676,8 @@ class DifferentiatingAttr final
DeclNameWithLoc Original;
/// The original function, resolved by the type checker.
FuncDecl *OriginalFunction = nullptr;
/// Whether this function is linear (optional).
bool linear;
/// Whether this function is linear.
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
Expand Down Expand Up @@ -1706,7 +1706,7 @@ class DifferentiatingAttr final

DeclNameWithLoc getOriginal() const { return Original; }

bool isLinear() const { return linear; }
bool isLinear() const { return Linear; }

FuncDecl *getOriginalFunction() const { return OriginalFunction; }
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
Expand Down
8 changes: 4 additions & 4 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
VJP(std::move(vjp)), WhereClause(clause) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
Expand All @@ -1461,7 +1461,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
Optional<DeclNameWithLoc> vjp,
GenericSignature *derivativeGenSig)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
ParameterIndices(indices) {
setDerivativeGenericSignature(context, derivativeGenSig);
}
Expand Down Expand Up @@ -1530,7 +1530,7 @@ DifferentiatingAttr::DifferentiatingAttr(
DeclNameWithLoc original, bool linear,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), linear(linear),
Original(std::move(original)), Linear(linear),
NumParsedParameters(params.size()) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
Expand All @@ -1540,7 +1540,7 @@ DifferentiatingAttr::DifferentiatingAttr(
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear, IndexSubset *indices)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), linear(linear), ParameterIndices(indices) {
Original(std::move(original)), Linear(linear), ParameterIndices(indices) {
}

DifferentiatingAttr *
Expand Down
6 changes: 1 addition & 5 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
paramIndices, origFnType);
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
assert(silDiffAttr->getIndices() == indices &&
"Expected matching @differentiable and [differentiable]");
"Expected matching @differentiable and [differentiable] indices");

auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule());
auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType(
Expand Down Expand Up @@ -875,10 +875,6 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) {
if (!hasFunction(thunk))
emitNativeToForeignThunk(thunk);
}

// TODO: Handle SILGen for `@differentiating` attribute.
// Tentative solution: SILGen derivative function normally but also emit
// mangled redirection thunk for retroactive differentiation.
}

void SILGenModule::emitFunction(FuncDecl *fd) {
Expand Down
6 changes: 3 additions & 3 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ struct DifferentiationInvoker {
Kind kind;
union Value {
/// The instruction associated with the `DifferentiableFunctionInst` case.
DifferentiableFunctionInst *adFuncInst;
Value(DifferentiableFunctionInst *inst) : adFuncInst(inst) {}
DifferentiableFunctionInst *diffFuncInst;
Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {}

/// The parent `apply` instruction and `[differentiable]` attribute
/// associated with the `IndirectDifferentiation` case.
Expand Down Expand Up @@ -385,7 +385,7 @@ struct DifferentiationInvoker {

DifferentiableFunctionInst *getDifferentiableFunctionInst() const {
assert(kind == Kind::DifferentiableFunctionInst);
return value.adFuncInst;
return value.diffFuncInst;
}

std::pair<ApplyInst *, SILDifferentiableAttr *>
Expand Down
2 changes: 1 addition & 1 deletion stdlib/public/core/DifferentiationSupport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic {
@differentiating(+)
@usableFromInline internal static func _jvpAdd(
lhs: AnyDerivative, rhs: AnyDerivative
) -> (value: AnyDerivative,
) -> (value: AnyDerivative,
differential: (AnyDerivative, AnyDerivative) -> (AnyDerivative)) {
return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs })
}
Expand Down