Skip to content

[AutoDiff] Gardening. #29038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,7 @@ class DifferentiableAttr final
// SWIFT_ENABLE_TENSORFLOW END
/// Whether this function is linear (optional).
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Optional<DeclNameRefWithLoc> JVP;
Expand All @@ -1737,7 +1737,7 @@ class DifferentiableAttr final
// SWIFT_ENABLE_TENSORFLOW
// NOTE: Parameter indices requestification is done on `tensorflow` branch but
// has not yet been upstreamed to `master` branch.
/// The differentiation parameters' indices, resolved by the type checker.
/// The differentiability parameter indices, resolved by the type checker.
/// The bit stores whether the parameter indices have been computed.
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
// SWIFT_ENABLE_TENSORFLOW END
Expand Down Expand Up @@ -1801,7 +1801,7 @@ class DifferentiableAttr final
void setParameterIndices(IndexSubset *paramIndices);
// SWIFT_ENABLE_TENSORFLOW END

/// The parsed differentiation parameters, i.e. the list of parameters
/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down Expand Up @@ -1852,15 +1852,15 @@ class DifferentiableAttr final
///
/// The `@derivative(of:)` attribute also has an optional `wrt:` clause
/// specifying the parameters that are differentiated "with respect to", i.e.
/// the differentiation parameters. The differentiation parameters must conform
/// to the `Differentiable` protocol.
/// the differentiability parameters. The differentiability parameters must
/// conform to the `Differentiable` protocol.
///
/// If the `wrt:` clause is unspecified, the differentiation parameters are
/// If the `wrt:` clause is unspecified, the differentiability parameters are
/// inferred to be all parameters that conform to `Differentiable`.
///
/// `@derivative(of:)` attribute type-checking verifies that the type of the
/// derivative function declaration is consistent with the type of the
/// referenced original declaration and the differentiation parameters.
/// referenced original declaration and the differentiability parameters.
///
/// Examples:
/// @derivative(of: sin(_:))
Expand All @@ -1879,9 +1879,9 @@ class DerivativeAttr final
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
/// The differentiability parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The derivative function kind (JVP or VJP), resolved by the type checker.
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
Expand Down Expand Up @@ -1924,7 +1924,7 @@ class DerivativeAttr final
}
/// Sets the derivative function kind (JVP or VJP).
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) {
  Kind = kind;
}

/// The parsed differentiation parameters, i.e. the list of parameters
/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down Expand Up @@ -1958,7 +1958,7 @@ using DifferentiatingAttr = DerivativeAttr;
/// computed property declaration.
///
/// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the
/// parameters that are transposed "with respect to", i.e. the transposed
/// parameters that are transposed "with respect to", i.e. the linearity
/// parameters.
///
/// Examples:
Expand All @@ -1978,9 +1978,9 @@ class TransposeAttr final
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed linearity parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The transposed parameters' indices, resolved by the type checker.
/// The linearity parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;

explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
Expand Down Expand Up @@ -2013,7 +2013,7 @@ class TransposeAttr final
OriginalFunction = decl;
}

/// The parsed transposed parameters, i.e. the list of parameters specified in
/// The parsed linearity parameters, i.e. the list of parameters specified in
/// 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down
11 changes: 6 additions & 5 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3035,18 +3035,19 @@ NOTE(derivative_attr_duplicate_note,none,
"other attribute declared here", ())

// @transpose
ERROR(transpose_params_clause_param_not_differentiable,none,
"can only transpose with respect to parameters that conform to "
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,
"cannot transpose with respect to original %select{result|parameter}1 "
"'%0' that does not conform to 'Differentiable' and satisfy "
"'%0 == %0.TangentVector'", (StringRef, /*isParameter*/ bool))
ERROR(transpose_attr_overload_not_found,none,
"could not find function %0 with expected type %1", (DeclName, Type))
ERROR(transpose_attr_cannot_use_named_wrt_params,none,
"cannot use named 'wrt' parameters in '@transpose(of:)' attribute, found "
"%0", (Identifier))
ERROR(transpose_func_wrt_self_must_be_static,none,
ERROR(transpose_attr_wrt_self_must_be_static,none,
"the transpose of an instance method must be a 'static' method in the "
"same type when 'self' is a linearity parameter", ())
NOTE(transpose_func_wrt_self_self_type_mismatch_note,none,
NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
"the transpose is declared in %0 but the original function is declared in "
"%1", (Type, Type))

Expand Down
8 changes: 4 additions & 4 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3235,12 +3235,12 @@ class AnyFunctionType : public TypeBase {
GenericSignature whereClauseGenericSignature = GenericSignature(),
bool makeSelfParamFirst = false);

/// Given the type of an autodiff derivative function, returns the
/// Given that `this` is an autodiff derivative function type, returns the
/// corresponding original function type.
AnyFunctionType *getAutoDiffOriginalFunctionType();
AnyFunctionType *getDerivativeOriginalFunctionType();

/// Given the type of a transpose function, returns the corresponding original
/// function type.
/// Given that `this` is an autodiff transpose function type, returns the
/// corresponding original function type.
AnyFunctionType *
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices, bool wrtSelf);

Expand Down
11 changes: 6 additions & 5 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1002,12 +1002,13 @@ class Parser {
Optional<DeclNameRefWithLoc> &vjpSpec,
TrailingWhereClause *&whereClause);

/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
/// `@differentiable` and `@derivative` attributes.
/// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
///
/// If `allowNamedParameters` is false, allow only index parameters and
/// 'self'.
bool parseDifferentiationParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
/// 'self'. Used for `@transpose` attributes.
bool parseDifferentiabilityParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &parameters, StringRef attrName,
bool allowNamedParameters = true);

/// Parse the @derivative attribute.
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
Printer.printNewline();
}

/// The kind of a differentiation parameter in a `wrt:` differentiation
/// parameters clause: differentiability or linearity. Used for printing
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
/// The kind of a parameter in a `wrt:` differentiation parameters clause:
/// either a differentiability parameter or a linearity parameter. Used for
/// printing `@differentiable`, `@derivative`, and `@transpose` attributes.
enum class DifferentiationParameterKind {
/// A differentiability parameter, printed by name.
/// Used for `@differentiable` and `@derivative` attributes.
Expand Down
69 changes: 53 additions & 16 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4914,8 +4914,9 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(

auto &ctx = getASTContext();

SmallVector<Type, 8> wrtParamTypes;
autodiff::getSubsetParameterTypes(indices, this, wrtParamTypes,
// Get differentiability parameter types.
SmallVector<Type, 8> diffParamTypes;
autodiff::getSubsetParameterTypes(indices, this, diffParamTypes,
/*reverseCurryLevels*/ !makeSelfParamFirst);

// Unwrap curry levels. At most, two parameter lists are necessary, for
Expand All @@ -4940,11 +4941,10 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
// closure is the JVP "differential":
// (T.TangentVector...) -> (R.TangentVector...)
SmallVector<AnyFunctionType::Param, 8> differentialParams;
for (auto wrtParamType : wrtParamTypes)
differentialParams.push_back(
AnyFunctionType::Param(
wrtParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType()));
for (auto diffParamType : diffParamTypes)
differentialParams.push_back(AnyFunctionType::Param(
diffParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType()));

SmallVector<TupleTypeElt, 8> differentialResults;
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
Expand Down Expand Up @@ -4984,9 +4984,9 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
}

SmallVector<TupleTypeElt, 8> pullbackResults;
for (auto wrtParamType : wrtParamTypes)
pullbackResults.push_back(wrtParamType
->getAutoDiffAssociatedTangentSpace(lookupConformance)
for (auto diffParamType : diffParamTypes)
pullbackResults.push_back(
diffParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType());
Type pullbackResult = pullbackResults.size() > 1
? TupleType::get(pullbackResults, ctx)
Expand Down Expand Up @@ -5068,10 +5068,47 @@ makeFunctionType(ArrayRef<AnyFunctionType::Param> params, Type retTy,
return FunctionType::get(params, retTy);
}

// Compute the original function type corresponding to the given transpose
// function type.
/// Given that `this` is an autodiff derivative function type, returns the
/// corresponding original function type.
///
/// The final result of the derivative type must be a two-element tuple; its
/// first element is used as the result type of the original function. The
/// parameter lists of every curry level are carried over unchanged.
AnyFunctionType *AnyFunctionType::getDerivativeOriginalFunctionType() {
  // Collect the curry levels, outermost first. At most two levels are
  // inspected, which is enough for curried method types with a `(Self)`
  // parameter list.
  SmallVector<AnyFunctionType *, 2> levels;
  for (auto *level = this; level != nullptr && levels.size() < 2;
       level = level->getResult()->getAs<AnyFunctionType>())
    levels.push_back(level);

  // The innermost collected level returns the derivative result tuple, whose
  // first element is the original result.
  AnyFunctionType *innermost = levels.back();
  auto *resultTuple = innermost->getResult()->getAs<TupleType>();
  assert(resultTuple && resultTuple->getNumElements() == 2 &&
         "Expected derivative result to be a two-element tuple");
  auto originalResult = resultTuple->getElement(0).getType();

  // Rebuild the innermost level with the original result. The generic
  // signature is passed here only when there is no outer curry level.
  auto *originalType = makeFunctionType(
      innermost, innermost->getParams(), originalResult,
      levels.size() == 1 ? getOptGenericSignature() : nullptr);

  // Wrap the original function type in the remaining outer curry levels,
  // walking from the innermost outward. The generic signature is passed only
  // for the outermost level.
  auto outerLevels = ArrayRef<AnyFunctionType *>(levels).drop_back(1);
  unsigned position = 0;
  for (AnyFunctionType *outer : llvm::reverse(outerLevels)) {
    bool isOutermost = position == outerLevels.size() - 1;
    originalType = makeFunctionType(
        outer, outer->getParams(), originalType,
        isOutermost ? getOptGenericSignature() : nullptr);
    ++position;
  }
  return originalType;
}

/// Given that `this` is an autodiff transpose function type, returns the
/// corresponding original function type.
AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
IndexSubset *wrtParamIndices, bool wrtSelf) {
IndexSubset *linearParamIndices, bool wrtSelf) {
unsigned transposeParamsIndex = 0;
bool isCurried = getResult()->is<AnyFunctionType>();

Expand Down Expand Up @@ -5121,18 +5158,18 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
// - The number of original transposed parameters.
// - This is the number of linearity parameters.
unsigned originalParameterCount =
transposeParams.size() - 1 + wrtParamIndices->getNumIndices();
transposeParams.size() - 1 + linearParamIndices->getNumIndices();
// Iterate over all original parameter indices.
for (auto i : range(originalParameterCount)) {
// Skip `self` parameter if `self` is a linearity parameter.
// The `self` is handled specially later to form a curried function type.
bool isSelfParameterAndWrtSelf =
wrtSelf && i == wrtParamIndices->getCapacity() - 1;
wrtSelf && i == linearParamIndices->getCapacity() - 1;
if (isSelfParameterAndWrtSelf)
continue;
// If `i` is a linearity parameter index, the next original parameter is
// the next transpose result.
if (wrtParamIndices->contains(i)) {
if (linearParamIndices->contains(i)) {
auto resultType =
transposeResultTypes[transposeResultTypesIndex++].getType();
originalParams.push_back(AnyFunctionType::Param(resultType));
Expand Down
Loading