Skip to content

Commit 88a7d61

Browse files
bartchr808 and dan-zheng committed
[AutoDiff] Fix transpose type checking. (#28950)
- TF-1060: fix `@transpose` typing rules for instance methods wrt `self`.
- The transpose of an instance method wrt `self` is now a static method in the same type context.
- Nice new invariant: transpose functions are always declared in the same type context as the original function.
- TF-997: support `@transpose` for initializer original declarations.

Add TF-1063 negative test: `@transpose` type-checking crash for static methods.

Co-authored-by: Dan Zheng <[email protected]>
1 parent 155b5ed commit 88a7d61

File tree

8 files changed

+353
-224
lines changed

8 files changed

+353
-224
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3043,9 +3043,12 @@ ERROR(transpose_attr_overload_not_found,none,
30433043
ERROR(transpose_attr_cannot_use_named_wrt_params,none,
30443044
"cannot use named 'wrt' parameters in '@transpose(of:)' attribute, found "
30453045
"%0", (Identifier))
3046-
ERROR(transpose_attr_result_value_not_differentiable,none,
3047-
"'@transpose(of:)' attribute requires original function result %0 to "
3048-
"conform to 'Differentiable'", (Type))
3046+
ERROR(transpose_func_wrt_self_must_be_static,none,
3047+
"the transpose of an instance method must be a 'static' method in the "
3048+
"same type when 'self' is a linearity parameter", ())
3049+
NOTE(transpose_func_wrt_self_self_type_mismatch_note,none,
3050+
"the transpose is declared in %0 but the original function is declared in "
3051+
"%1", (Type, Type))
30493052

30503053
// Automatic differentiation attributes
30513054
ERROR(autodiff_attr_original_decl_invalid_kind,none,

lib/AST/Attr.cpp

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -372,51 +372,60 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
372372
Printer.printNewline();
373373
}
374374

375-
/// Printing style for a differentiation parameter in a `wrt:` differentiation
376-
/// parameters clause. Used for printing `@differentiable`, `@derivative`, and
377-
/// `@transpose` attributes.
378-
enum class DifferentiationParameterPrintingStyle {
379-
/// Print parameter by name.
375+
/// The kind of a differentiation parameter in a `wrt:` differentiation
376+
/// parameters clause: differentiability or linearity. Used for printing
377+
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
378+
enum class DifferentiationParameterKind {
379+
/// A differentiability parameter, printed by name.
380380
/// Used for `@differentiable` and `@derivative` attribute.
381-
Name,
382-
/// Print parameter by index.
381+
Differentiability,
382+
/// A linearity parameter, printed by index.
383383
/// Used for `@transpose` attribute.
384-
Index
384+
Linearity
385385
};
386386

387387
/// Returns the differentiation parameters clause string for the given function,
388-
/// parameter indices, parsed parameters, . Use the parameter indices if
389-
/// specified; otherwise, use the parsed parameters.
388+
/// parameter indices, parsed parameters, and differentiation parameter kind.
389+
/// Use the parameter indices if specified; otherwise, use the parsed
390+
/// parameters.
390391
static std::string getDifferentiationParametersClauseString(
391-
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
392+
const AbstractFunctionDecl *function, IndexSubset *parameterIndices,
392393
ArrayRef<ParsedAutoDiffParameter> parsedParams,
393-
DifferentiationParameterPrintingStyle style) {
394+
DifferentiationParameterKind parameterKind) {
394395
assert(function);
395396
bool isInstanceMethod = function->isInstanceMember();
397+
bool isStaticMethod = function->isStatic();
396398
std::string result;
397399
llvm::raw_string_ostream printer(result);
398400

399401
// Use the parameter indices, if specified.
400-
if (paramIndices) {
401-
auto parameters = paramIndices->getBitVector();
402+
if (parameterIndices) {
403+
auto parameters = parameterIndices->getBitVector();
402404
auto parameterCount = parameters.count();
403405
printer << "wrt: ";
404406
if (parameterCount > 1)
405407
printer << '(';
406408
// Check if differentiating wrt `self`. If so, manually print it first.
407-
if (isInstanceMethod && parameters.test(parameters.size() - 1)) {
409+
bool isWrtSelf =
410+
(isInstanceMethod ||
411+
(isStaticMethod &&
412+
parameterKind == DifferentiationParameterKind::Linearity)) &&
413+
parameters.test(parameters.size() - 1);
414+
if (isWrtSelf) {
408415
parameters.reset(parameters.size() - 1);
409416
printer << "self";
410417
if (parameters.any())
411418
printer << ", ";
412419
}
413420
// Print remaining differentiation parameters.
414421
interleave(parameters.set_bits(), [&](unsigned index) {
415-
switch (style) {
416-
case DifferentiationParameterPrintingStyle::Name:
422+
switch (parameterKind) {
423+
// Print differentiability parameters by name.
424+
case DifferentiationParameterKind::Differentiability:
417425
printer << function->getParameters()->get(index)->getName().str();
418426
break;
419-
case DifferentiationParameterPrintingStyle::Index:
427+
// Print linearity parameters by index.
428+
case DifferentiationParameterKind::Linearity:
420429
printer << index;
421430
break;
422431
}
@@ -493,7 +502,7 @@ static void printDifferentiableAttrArguments(
493502
if (!omitWrtClause) {
494503
auto diffParamsString = getDifferentiationParametersClauseString(
495504
original, attr->getParameterIndices(), attr->getParsedParameters(),
496-
DifferentiationParameterPrintingStyle::Name);
505+
DifferentiationParameterKind::Differentiability);
497506
// Check whether differentiation parameter clause is empty.
498507
// Handles edge case where resolved parameter indices are unset and
499508
// parsed parameters are empty. This case should never trigger for
@@ -933,7 +942,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
933942
auto *derivative = cast<AbstractFunctionDecl>(D);
934943
auto diffParamsString = getDifferentiationParametersClauseString(
935944
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
936-
DifferentiationParameterPrintingStyle::Name);
945+
DifferentiationParameterKind::Differentiability);
937946
if (!diffParamsString.empty())
938947
Printer << ", " << diffParamsString;
939948
Printer << ')';
@@ -948,7 +957,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
948957
auto *transpose = cast<AbstractFunctionDecl>(D);
949958
auto transParamsString = getDifferentiationParametersClauseString(
950959
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
951-
DifferentiationParameterPrintingStyle::Index);
960+
DifferentiationParameterKind::Linearity);
952961
if (!transParamsString.empty())
953962
Printer << ", " << transParamsString;
954963
Printer << ')';

lib/AST/Type.cpp

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5075,7 +5075,7 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
50755075
unsigned transposeParamsIndex = 0;
50765076
bool isCurried = getResult()->is<AnyFunctionType>();
50775077

5078-
// Get the original function's result.
5078+
// Get the transpose function's parameters and result type.
50795079
auto transposeParams = getParams();
50805080
auto transposeResult = getResult();
50815081
if (isCurried) {
@@ -5084,79 +5084,81 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
50845084
transposeResult = methodType->getResult();
50855085
}
50865086

5087-
Type originalResult;
5088-
if (isCurried) {
5089-
// If it's curried, then the first parameter in the curried type, which is
5090-
// the 'Self' type, is the original result (no matter if we are
5091-
// transposing wrt self or not).
5092-
originalResult = getParams().front().getPlainType();
5093-
} else {
5094-
// If it's not curried, the last parameter, the tangent, is always the
5095-
// original result type as we require the last parameter of the transposing
5096-
// function to be the original result.
5097-
originalResult = transposeParams.back().getPlainType();
5098-
}
5099-
assert(originalResult);
5087+
// Get the original function's result type.
5088+
// The original result type is always equal to the type of the last
5089+
// parameter of the transpose function type.
5090+
auto originalResult = transposeParams.back().getPlainType();
51005091

5092+
// Get transposed result types.
5093+
// The transpose function result type may be a singular type or a tuple type.
51015094
SmallVector<TupleTypeElt, 4> transposeResultTypes;
5102-
// Return type of transpose function can be a singular type or a tuple type.
51035095
if (auto transposeResultTupleType = transposeResult->getAs<TupleType>()) {
51045096
transposeResultTypes.append(transposeResultTupleType->getElements().begin(),
51055097
transposeResultTupleType->getElements().end());
51065098
} else {
51075099
transposeResultTypes.push_back(transposeResult);
51085100
}
5109-
assert(!transposeResultTypes.empty());
51105101

5111-
// If the function is curried and is transposing wrt 'self', then grab
5112-
// the type from the result list (guaranteed to be the first since 'self'
5113-
// is first in wrt list) and remove it. If it is still curried but not
5114-
// transposing wrt 'self', then the 'Self' type is the first parameter
5115-
// in the method.
5102+
// Get the `Self` type, if the transpose function type is curried.
5103+
// - If `self` is a linearity parameter, use the first transpose result type.
5104+
// - Otherwise, use the first transpose parameter type.
51165105
unsigned transposeResultTypesIndex = 0;
51175106
Type selfType;
51185107
if (isCurried && wrtSelf) {
51195108
selfType = transposeResultTypes.front().getType();
51205109
transposeResultTypesIndex++;
51215110
} else if (isCurried) {
5122-
selfType = transposeParams.front().getPlainType();
5123-
transposeParamsIndex++;
5111+
selfType = getParams().front().getPlainType();
51245112
}
51255113

5114+
// Get the original function's parameters.
51265115
SmallVector<AnyFunctionType::Param, 8> originalParams;
5116+
// The number of original parameters is equal to the sum of:
5117+
// - The number of original non-transposed parameters.
5118+
// - This is the number of transpose parameters minus one. All transpose
5119+
// parameters come from the original function, except the last parameter
5120+
// (the transposed original result).
5121+
// - The number of original transposed parameters.
5122+
// - This is the number of linearity parameters.
51275123
unsigned originalParameterCount =
5128-
transposeParams.size() + wrtParamIndices->getNumIndices() - 1;
5124+
transposeParams.size() - 1 + wrtParamIndices->getNumIndices();
5125+
// Iterate over all original parameter indices.
51295126
for (auto i : range(originalParameterCount)) {
5130-
// Need to check if it is the 'self' param since we handle it differently
5131-
// above.
5132-
bool lookingAtSelf = (i == wrtParamIndices->getCapacity() - 1) && wrtSelf;
5133-
if (wrtParamIndices->contains(i) && !lookingAtSelf) {
5134-
// If in wrt list, the item in the result tuple must be a parameter in the
5135-
// original function.
5127+
// Skip `self` parameter if `self` is a linearity parameter.
5128+
// The `self` is handled specially later to form a curried function type.
5129+
bool isSelfParameterAndWrtSelf =
5130+
wrtSelf && i == wrtParamIndices->getCapacity() - 1;
5131+
if (isSelfParameterAndWrtSelf)
5132+
continue;
5133+
// If `i` is a linearity parameter index, the next original parameter is
5134+
// the next transpose result.
5135+
if (wrtParamIndices->contains(i)) {
51365136
auto resultType =
5137-
transposeResultTypes[transposeResultTypesIndex].getType();
5137+
transposeResultTypes[transposeResultTypesIndex++].getType();
51385138
originalParams.push_back(AnyFunctionType::Param(resultType));
5139-
transposeResultTypesIndex++;
5140-
} else {
5141-
// Else if not in the wrt list, the parameter in the transposing function
5142-
// is a parameter in the original function.
5143-
originalParams.push_back(transposeParams[transposeParamsIndex]);
5144-
transposeParamsIndex++;
5139+
}
5140+
// Otherwise, the next original parameter is the next transpose parameter.
5141+
else {
5142+
originalParams.push_back(transposeParams[transposeParamsIndex++]);
51455143
}
51465144
}
51475145

5146+
// Compute the original function type.
51485147
AnyFunctionType *originalType;
5148+
// If the transpose type is curried, the original function type is:
5149+
// `(Self) -> (<original parameters>) -> <original result>`.
51495150
if (isCurried) {
5150-
assert(selfType);
5151-
// If curried, wrap the function into the 'Self' type to get a method.
5151+
assert(selfType && "`Self` type should be resolved");
51525152
originalType = makeFunctionType(originalParams, originalResult, nullptr);
51535153
originalType = makeFunctionType(AnyFunctionType::Param(selfType),
51545154
originalType, getOptGenericSignature());
5155-
} else {
5155+
}
5156+
// Otherwise, the original function type is simply:
5157+
// `(<original parameters>) -> <original result>`.
5158+
else {
51565159
originalType = makeFunctionType(originalParams, originalResult,
51575160
getOptGenericSignature());
51585161
}
5159-
assert(originalType);
51605162
return originalType;
51615163
}
51625164

0 commit comments

Comments
 (0)