Skip to content

Commit 239dd3b

Browse files
authored
[AutoDiff] Gardening. (#29038)
- Clean up AutoDiff attribute parsing, printing, and type-checking.
- Rename "wrt parameters" to "differentiability/linearity parameters".
- Improve `@transpose` attribute diagnostics.
- Generalize the diagnostic for an invalid original result in the `@transpose` attribute.
- Uncomment tests in test/AutoDiff/transpose_attr_type_checking.swift.

NFC except for `@transpose` attribute diagnostic message changes.
1 parent 88a7d61 commit 239dd3b

File tree

13 files changed

+320
-336
lines changed

13 files changed

+320
-336
lines changed

include/swift/AST/Attr.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,7 +1722,7 @@ class DifferentiableAttr final
17221722
// SWIFT_ENABLE_TENSORFLOW END
17231723
/// Whether this function is linear (optional).
17241724
bool Linear;
1725-
/// The number of parsed parameters specified in 'wrt:'.
1725+
/// The number of parsed differentiability parameters specified in 'wrt:'.
17261726
unsigned NumParsedParameters = 0;
17271727
/// The JVP function.
17281728
Optional<DeclNameRefWithLoc> JVP;
@@ -1737,7 +1737,7 @@ class DifferentiableAttr final
17371737
// SWIFT_ENABLE_TENSORFLOW
17381738
// NOTE: Parameter indices requestification is done on `tensorflow` branch but
17391739
// has not yet been upstreamed to `master` branch.
1740-
/// The differentiation parameters' indices, resolved by the type checker.
1740+
/// The differentiability parameter indices, resolved by the type checker.
17411741
/// The bit stores whether the parameter indices have been computed.
17421742
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
17431743
// SWIFT_ENABLE_TENSORFLOW END
@@ -1801,7 +1801,7 @@ class DifferentiableAttr final
18011801
void setParameterIndices(IndexSubset *paramIndices);
18021802
// SWIFT_ENABLE_TENSORFLOW END
18031803

1804-
/// The parsed differentiation parameters, i.e. the list of parameters
1804+
/// The parsed differentiability parameters, i.e. the list of parameters
18051805
/// specified in 'wrt:'.
18061806
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
18071807
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
@@ -1852,15 +1852,15 @@ class DifferentiableAttr final
18521852
///
18531853
/// The `@derivative(of:)` attribute also has an optional `wrt:` clause
18541854
/// specifying the parameters that are differentiated "with respect to", i.e.
1855-
/// the differentiation parameters. The differentiation parameters must conform
1856-
/// to the `Differentiable` protocol.
1855+
/// the differentiability parameters. The differentiability parameters must
1856+
/// conform to the `Differentiable` protocol.
18571857
///
1858-
/// If the `wrt:` clause is unspecified, the differentiation parameters are
1858+
/// If the `wrt:` clause is unspecified, the differentiability parameters are
18591859
/// inferred to be all parameters that conform to `Differentiable`.
18601860
///
18611861
/// `@derivative(of:)` attribute type-checking verifies that the type of the
18621862
/// derivative function declaration is consistent with the type of the
1863-
/// referenced original declaration and the differentiation parameters.
1863+
/// referenced original declaration and the differentiability parameters.
18641864
///
18651865
/// Examples:
18661866
/// @derivative(of: sin(_:))
@@ -1879,9 +1879,9 @@ class DerivativeAttr final
18791879
DeclNameRefWithLoc OriginalFunctionName;
18801880
/// The original function declaration, resolved by the type checker.
18811881
AbstractFunctionDecl *OriginalFunction = nullptr;
1882-
/// The number of parsed parameters specified in 'wrt:'.
1882+
/// The number of parsed differentiability parameters specified in 'wrt:'.
18831883
unsigned NumParsedParameters = 0;
1884-
/// The differentiation parameters' indices, resolved by the type checker.
1884+
/// The differentiability parameter indices, resolved by the type checker.
18851885
IndexSubset *ParameterIndices = nullptr;
18861886
/// The derivative function kind (JVP or VJP), resolved by the type checker.
18871887
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
@@ -1924,7 +1924,7 @@ class DerivativeAttr final
19241924
}
19251925
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
19261926

1927-
/// The parsed differentiation parameters, i.e. the list of parameters
1927+
/// The parsed differentiability parameters, i.e. the list of parameters
19281928
/// specified in 'wrt:'.
19291929
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
19301930
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
@@ -1958,7 +1958,7 @@ using DifferentiatingAttr = DerivativeAttr;
19581958
/// computed property declaration.
19591959
///
19601960
/// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the
1961-
/// parameters that are transposed "with respect to", i.e. the transposed
1961+
/// parameters that are transposed "with respect to", i.e. the linearity
19621962
/// parameters.
19631963
///
19641964
/// Examples:
@@ -1978,9 +1978,9 @@ class TransposeAttr final
19781978
DeclNameRefWithLoc OriginalFunctionName;
19791979
/// The original function declaration, resolved by the type checker.
19801980
AbstractFunctionDecl *OriginalFunction = nullptr;
1981-
/// The number of parsed parameters specified in 'wrt:'.
1981+
/// The number of parsed linearity parameters specified in 'wrt:'.
19821982
unsigned NumParsedParameters = 0;
1983-
/// The transposed parameters' indices, resolved by the type checker.
1983+
/// The linearity parameter indices, resolved by the type checker.
19841984
IndexSubset *ParameterIndices = nullptr;
19851985

19861986
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
@@ -2013,7 +2013,7 @@ class TransposeAttr final
20132013
OriginalFunction = decl;
20142014
}
20152015

2016-
/// The parsed transposed parameters, i.e. the list of parameters specified in
2016+
/// The parsed linearity parameters, i.e. the list of parameters specified in
20172017
/// 'wrt:'.
20182018
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
20192019
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3035,18 +3035,19 @@ NOTE(derivative_attr_duplicate_note,none,
30353035
"other attribute declared here", ())
30363036

30373037
// @transpose
3038-
ERROR(transpose_params_clause_param_not_differentiable,none,
3039-
"can only transpose with respect to parameters that conform to "
3040-
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
3038+
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,
3039+
"cannot transpose with respect to original %select{result|parameter}1 "
3040+
"'%0' that does not conform to 'Differentiable' and satisfy "
3041+
"'%0 == %0.TangentVector'", (StringRef, /*isParameter*/ bool))
30413042
ERROR(transpose_attr_overload_not_found,none,
30423043
"could not find function %0 with expected type %1", (DeclName, Type))
30433044
ERROR(transpose_attr_cannot_use_named_wrt_params,none,
30443045
"cannot use named 'wrt' parameters in '@transpose(of:)' attribute, found "
30453046
"%0", (Identifier))
3046-
ERROR(transpose_func_wrt_self_must_be_static,none,
3047+
ERROR(transpose_attr_wrt_self_must_be_static,none,
30473048
"the transpose of an instance method must be a 'static' method in the "
30483049
"same type when 'self' is a linearity parameter", ())
3049-
NOTE(transpose_func_wrt_self_self_type_mismatch_note,none,
3050+
NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
30503051
"the transpose is declared in %0 but the original function is declared in "
30513052
"%1", (Type, Type))
30523053

include/swift/AST/Types.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3235,12 +3235,12 @@ class AnyFunctionType : public TypeBase {
32353235
GenericSignature whereClauseGenericSignature = GenericSignature(),
32363236
bool makeSelfParamFirst = false);
32373237

3238-
/// Given the type of an autodiff derivative function, returns the
3238+
/// Given that `this` is an autodiff derivative function type, returns the
32393239
/// corresponding original function type.
3240-
AnyFunctionType *getAutoDiffOriginalFunctionType();
3240+
AnyFunctionType *getDerivativeOriginalFunctionType();
32413241

3242-
/// Given the type of a transpose function, returns the corresponding original
3243-
/// function type.
3242+
/// Given that `this` is an autodiff transpose function type, returns the
3243+
/// corresponding original function type.
32443244
AnyFunctionType *
32453245
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices, bool wrtSelf);
32463246

include/swift/Parse/Parser.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,12 +1002,13 @@ class Parser {
10021002
Optional<DeclNameRefWithLoc> &vjpSpec,
10031003
TrailingWhereClause *&whereClause);
10041004

1005-
/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
1006-
/// `@differentiable` and `@derivative` attributes.
1005+
/// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in
1006+
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
1007+
///
10071008
/// If `allowNamedParameters` is false, allow only index parameters and
1008-
/// 'self'.
1009-
bool parseDifferentiationParametersClause(
1010-
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
1009+
/// 'self'. Used for `@transpose` attributes.
1010+
bool parseDifferentiabilityParametersClause(
1011+
SmallVectorImpl<ParsedAutoDiffParameter> &parameters, StringRef attrName,
10111012
bool allowNamedParameters = true);
10121013

10131014
/// Parse the @derivative attribute.

lib/AST/Attr.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,9 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
372372
Printer.printNewline();
373373
}
374374

375-
/// The kind of a differentiation parameter in a `wrt:` differentiation
376-
/// parameters clause: differentiability or linearity. Used for printing
377-
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
375+
/// The kind of a parameter in a `wrt:` differentiation parameters clause:
376+
/// either a differentiability parameter or a linearity parameter. Used for
377+
/// printing `@differentiable`, `@derivative`, and `@transpose` attributes.
378378
enum class DifferentiationParameterKind {
379379
/// A differentiability parameter, printed by name.
380380
/// Used for `@differentiable` and `@derivative` attribute.

lib/AST/Type.cpp

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4914,8 +4914,9 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
49144914

49154915
auto &ctx = getASTContext();
49164916

4917-
SmallVector<Type, 8> wrtParamTypes;
4918-
autodiff::getSubsetParameterTypes(indices, this, wrtParamTypes,
4917+
// Get differentiability parameter types.
4918+
SmallVector<Type, 8> diffParamTypes;
4919+
autodiff::getSubsetParameterTypes(indices, this, diffParamTypes,
49194920
/*reverseCurryLevels*/ !makeSelfParamFirst);
49204921

49214922
// Unwrap curry levels. At most, two parameter lists are necessary, for
@@ -4940,11 +4941,10 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
49404941
// closure is the JVP "differential":
49414942
// (T.TangentVector...) -> (R.TangentVector...)
49424943
SmallVector<AnyFunctionType::Param, 8> differentialParams;
4943-
for (auto wrtParamType : wrtParamTypes)
4944-
differentialParams.push_back(
4945-
AnyFunctionType::Param(
4946-
wrtParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
4947-
->getType()));
4944+
for (auto diffParamType : diffParamTypes)
4945+
differentialParams.push_back(AnyFunctionType::Param(
4946+
diffParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
4947+
->getType()));
49484948

49494949
SmallVector<TupleTypeElt, 8> differentialResults;
49504950
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
@@ -4984,9 +4984,9 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
49844984
}
49854985

49864986
SmallVector<TupleTypeElt, 8> pullbackResults;
4987-
for (auto wrtParamType : wrtParamTypes)
4988-
pullbackResults.push_back(wrtParamType
4989-
->getAutoDiffAssociatedTangentSpace(lookupConformance)
4987+
for (auto diffParamType : diffParamTypes)
4988+
pullbackResults.push_back(
4989+
diffParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
49904990
->getType());
49914991
Type pullbackResult = pullbackResults.size() > 1
49924992
? TupleType::get(pullbackResults, ctx)
@@ -5068,10 +5068,47 @@ makeFunctionType(ArrayRef<AnyFunctionType::Param> params, Type retTy,
50685068
return FunctionType::get(params, retTy);
50695069
}
50705070

5071-
// Compute the original function type corresponding to the given transpose
5072-
// function type.
5071+
/// Given that `this` is an autodiff derivative function type, returns the
5072+
/// corresponding original function type.
5073+
AnyFunctionType *AnyFunctionType::getDerivativeOriginalFunctionType() {
5074+
// Unwrap curry levels. At most, two parameter lists are necessary, for
5075+
// curried method types with a `(Self)` parameter list.
5076+
SmallVector<AnyFunctionType *, 2> curryLevels;
5077+
auto *currentLevel = this;
5078+
for (unsigned i : range(2)) {
5079+
(void)i;
5080+
if (currentLevel == nullptr)
5081+
break;
5082+
curryLevels.push_back(currentLevel);
5083+
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
5084+
}
5085+
5086+
auto derivativeResult = curryLevels.back()->getResult()->getAs<TupleType>();
5087+
assert(derivativeResult && derivativeResult->getNumElements() == 2 &&
5088+
"Expected derivative result to be a two-element tuple");
5089+
auto originalResult = derivativeResult->getElement(0).getType();
5090+
auto *originalType = makeFunctionType(
5091+
curryLevels.back(), curryLevels.back()->getParams(), originalResult,
5092+
curryLevels.size() == 1 ? getOptGenericSignature() : nullptr);
5093+
5094+
// Wrap the original function type in additional curry levels.
5095+
auto curryLevelsWithoutLast =
5096+
ArrayRef<AnyFunctionType *>(curryLevels).drop_back(1);
5097+
for (auto pair : enumerate(llvm::reverse(curryLevelsWithoutLast))) {
5098+
unsigned i = pair.index();
5099+
AnyFunctionType *curryLevel = pair.value();
5100+
originalType = makeFunctionType(
5101+
curryLevel, curryLevel->getParams(), originalType,
5102+
i == curryLevelsWithoutLast.size() - 1 ? getOptGenericSignature()
5103+
: nullptr);
5104+
}
5105+
return originalType;
5106+
}
5107+
5108+
/// Given that `this` is an autodiff transpose function type, returns the
5109+
/// corresponding original function type.
50735110
AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
5074-
IndexSubset *wrtParamIndices, bool wrtSelf) {
5111+
IndexSubset *linearParamIndices, bool wrtSelf) {
50755112
unsigned transposeParamsIndex = 0;
50765113
bool isCurried = getResult()->is<AnyFunctionType>();
50775114

@@ -5121,18 +5158,18 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
51215158
// - The number of original transposed parameters.
51225159
// - This is the number of linearity parameters.
51235160
unsigned originalParameterCount =
5124-
transposeParams.size() - 1 + wrtParamIndices->getNumIndices();
5161+
transposeParams.size() - 1 + linearParamIndices->getNumIndices();
51255162
// Iterate over all original parameter indices.
51265163
for (auto i : range(originalParameterCount)) {
51275164
// Skip `self` parameter if `self` is a linearity parameter.
51285165
// The `self` is handled specially later to form a curried function type.
51295166
bool isSelfParameterAndWrtSelf =
5130-
wrtSelf && i == wrtParamIndices->getCapacity() - 1;
5167+
wrtSelf && i == linearParamIndices->getCapacity() - 1;
51315168
if (isSelfParameterAndWrtSelf)
51325169
continue;
51335170
// If `i` is a linearity parameter index, the next original parameter is
51345171
// the next transpose result.
5135-
if (wrtParamIndices->contains(i)) {
5172+
if (linearParamIndices->contains(i)) {
51365173
auto resultType =
51375174
transposeResultTypes[transposeResultTypesIndex++].getType();
51385175
originalParams.push_back(AnyFunctionType::Param(resultType));

0 commit comments

Comments
 (0)