Skip to content

Commit e4c8c8e

Browse files
authored
[AutoDiff] Clean up @transposing attribute type-checking. (#27988)
Remove dead code and unused parameters, unify style. Fix capacity of `@transposing` attribute parameter indices. The correct capacity is: non-wrt parameter count + wrt parameter count - 1. Todos: - More clean up; this clean up was not comprehensive. - Consider revamping type-checking for instance methods so that `@transposing` functions can always be declared in same type context as their original declaration.
1 parent 07db198 commit e4c8c8e

File tree

5 files changed

+133
-201
lines changed

5 files changed

+133
-201
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2952,8 +2952,8 @@ ERROR(transpose_params_clause_param_not_differentiable,none,
29522952
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
29532953
ERROR(transposing_attr_overload_not_found,none,
29542954
"could not find function %0 with expected type %1", (DeclName, Type))
2955-
ERROR(transposing_attr_cant_use_named_wrt_params,none,
2956-
"cannot use named wrt parameters in '@transposing' attribute, found %0",
2955+
ERROR(transposing_attr_cannot_use_named_wrt_params,none,
2956+
"cannot use named 'wrt' parameters in '@transposing' attribute, found %0",
29572957
(Identifier))
29582958
ERROR(transposing_attr_result_value_not_differentiable,none,
29592959
"'@transposing' attribute requires original function result to "

include/swift/AST/Types.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3124,12 +3124,11 @@ class AnyFunctionType : public TypeBase {
31243124
/// Given the type of an autodiff derivative function, returns the
31253125
/// corresponding original function type.
31263126
AnyFunctionType *getAutoDiffOriginalFunctionType();
3127-
3127+
31283128
/// Given the type of a transposing derivative function, returns the
31293129
/// corresponding original function type.
31303130
AnyFunctionType *
3131-
getTransposeOriginalFunctionType(TransposingAttr *attr,
3132-
IndexSubset *wrtParamIndices,
3131+
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices,
31333132
bool wrtSelf);
31343133

31353134
AnyFunctionType *getWithoutDifferentiability() const;

lib/AST/Type.cpp

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4742,25 +4742,24 @@ makeFunctionType(ArrayRef<AnyFunctionType::Param> params, Type retTy,
47424742
// Compute the original function type corresponding to the given transpose
47434743
// function type.
47444744
AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
4745-
TransposingAttr *attr, IndexSubset *wrtParamIndices, bool wrtSelf) {
4745+
IndexSubset *wrtParamIndices, bool wrtSelf) {
47464746
unsigned transposeParamsIndex = 0;
47474747
bool isCurried = getResult()->is<AnyFunctionType>();
4748-
4748+
47494749
// Get the original function's result.
47504750
auto transposeParams = getParams();
47514751
auto transposeResult = getResult();
47524752
if (isCurried) {
4753-
auto method =
4754-
getAs<AnyFunctionType>()->getResult()->getAs<AnyFunctionType>();
4755-
transposeParams = method->getParams();
4756-
transposeResult = method->getResult();
4753+
auto methodType = getResult()->castTo<AnyFunctionType>();
4754+
transposeParams = methodType->getParams();
4755+
transposeResult = methodType->getResult();
47574756
}
4758-
4757+
47594758
Type originalResult;
47604759
if (isCurried) {
47614760
// If it's curried, then the first parameter in the curried type, which is
47624761
// the 'Self' type, is the original result (no matter if we are
4763-
// differentiating WRT self or aren't).
4762+
// transposing wrt self or not).
47644763
originalResult = getParams().front().getPlainType();
47654764
} else {
47664765
// If it's not curried, the last parameter, the tangent, is always the
@@ -4770,22 +4769,21 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
47704769
}
47714770
assert(originalResult);
47724771

4773-
auto wrtParams = attr->getParsedParameters();
47744772
SmallVector<TupleTypeElt, 4> transposeResultTypes;
47754773
// Return type of '@transposing' function can have single type or tuples
47764774
// of types.
4777-
if (auto t = transposeResult->getAs<TupleType>()) {
4778-
transposeResultTypes.append(t->getElements().begin(),
4779-
t->getElements().end());
4775+
if (auto transposeResultTupleType = transposeResult->getAs<TupleType>()) {
4776+
transposeResultTypes.append(transposeResultTupleType->getElements().begin(),
4777+
transposeResultTupleType->getElements().end());
47804778
} else {
47814779
transposeResultTypes.push_back(transposeResult);
47824780
}
47834781
assert(!transposeResultTypes.empty());
47844782

4785-
// If the function is curried and is transposing WRT 'self', then grab
4783+
// If the function is curried and is transposing wrt 'self', then grab
47864784
// the type from the result list (guaranteed to be the first since 'self'
4787-
// is first in WRT list) and remove it. If it's still curried but not
4788-
// transposing WRT 'self', then the 'Self' type is the first parameter
4785+
// is first in wrt list) and remove it. If it is still curried but not
4786+
// transposing wrt 'self', then the 'Self' type is the first parameter
47894787
// in the method.
47904788
unsigned transposeResultTypesIndex = 0;
47914789
Type selfType;
@@ -4798,21 +4796,21 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
47984796
}
47994797

48004798
SmallVector<AnyFunctionType::Param, 8> originalParams;
4801-
unsigned numberOriginalParameters =
4802-
transposeParams.size() + wrtParams.size() - 1;
4803-
for (auto i : range(numberOriginalParameters)) {
4799+
unsigned originalParameterCount =
4800+
transposeParams.size() + wrtParamIndices->getNumIndices() - 1;
4801+
for (auto i : range(originalParameterCount)) {
48044802
// Need to check if it is the 'self' param since we handle it differently
48054803
// above.
4806-
bool lookingAtSelf = (i == (wrtParamIndices->getCapacity() - 1)) && wrtSelf;
4807-
bool isWrt = wrtParamIndices->contains(i);
4808-
if (isWrt && !lookingAtSelf) {
4809-
// If in WRT list, the item in the result tuple must be a parameter in the
4804+
bool lookingAtSelf = (i == wrtParamIndices->getCapacity() - 1) && wrtSelf;
4805+
if (wrtParamIndices->contains(i) && !lookingAtSelf) {
4806+
// If in wrt list, the item in the result tuple must be a parameter in the
48104807
// original function.
4811-
auto resultType = transposeResultTypes[transposeResultTypesIndex].getType();
4808+
auto resultType =
4809+
transposeResultTypes[transposeResultTypesIndex].getType();
48124810
originalParams.push_back(AnyFunctionType::Param(resultType));
48134811
transposeResultTypesIndex++;
48144812
} else {
4815-
// Else if not in the WRT list, the parameter in the transposing function
4813+
// Else if not in the wrt list, the parameter in the transposing function
48164814
// is a parameter in the original function.
48174815
originalParams.push_back(transposeParams[transposeParamsIndex]);
48184816
transposeParamsIndex++;

0 commit comments

Comments
 (0)