@@ -5075,7 +5075,7 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
5075
5075
unsigned transposeParamsIndex = 0 ;
5076
5076
bool isCurried = getResult ()->is <AnyFunctionType>();
5077
5077
5078
- // Get the original function's result.
5078
+ // Get the transpose function's parameters and result type .
5079
5079
auto transposeParams = getParams ();
5080
5080
auto transposeResult = getResult ();
5081
5081
if (isCurried) {
@@ -5084,79 +5084,81 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
5084
5084
transposeResult = methodType->getResult ();
5085
5085
}
5086
5086
5087
- Type originalResult;
5088
- if (isCurried) {
5089
- // If it's curried, then the first parameter in the curried type, which is
5090
- // the 'Self' type, is the original result (no matter if we are
5091
- // transposing wrt self or not).
5092
- originalResult = getParams ().front ().getPlainType ();
5093
- } else {
5094
- // If it's not curried, the last parameter, the tangent, is always the
5095
- // original result type as we require the last parameter of the transposing
5096
- // function to be the original result.
5097
- originalResult = transposeParams.back ().getPlainType ();
5098
- }
5099
- assert (originalResult);
5087
+ // Get the original function's result type.
5088
+ // The original result type is always equal to the type of the last
5089
+ // parameter of the transpose function type.
5090
+ auto originalResult = transposeParams.back ().getPlainType ();
5100
5091
5092
+ // Get transposed result types.
5093
+ // The transpose function result type may be a singular type or a tuple type.
5101
5094
SmallVector<TupleTypeElt, 4 > transposeResultTypes;
5102
- // Return type of transpose function can be a singular type or a tuple type.
5103
5095
if (auto transposeResultTupleType = transposeResult->getAs <TupleType>()) {
5104
5096
transposeResultTypes.append (transposeResultTupleType->getElements ().begin (),
5105
5097
transposeResultTupleType->getElements ().end ());
5106
5098
} else {
5107
5099
transposeResultTypes.push_back (transposeResult);
5108
5100
}
5109
- assert (!transposeResultTypes.empty ());
5110
5101
5111
- // If the function is curried and is transposing wrt 'self', then grab
5112
- // the type from the result list (guaranteed to be the first since 'self'
5113
- // is first in wrt list) and remove it. If it is still curried but not
5114
- // transposing wrt 'self', then the 'Self' type is the first parameter
5115
- // in the method.
5102
+ // Get the `Self` type, if the transpose function type is curried.
5103
+ // - If `self` is a linearity parameter, use the first transpose result type.
5104
+ // - Otherwise, use the first transpose parameter type.
5116
5105
unsigned transposeResultTypesIndex = 0 ;
5117
5106
Type selfType;
5118
5107
if (isCurried && wrtSelf) {
5119
5108
selfType = transposeResultTypes.front ().getType ();
5120
5109
transposeResultTypesIndex++;
5121
5110
} else if (isCurried) {
5122
- selfType = transposeParams.front ().getPlainType ();
5123
- transposeParamsIndex++;
5111
+ selfType = getParams ().front ().getPlainType ();
5124
5112
}
5125
5113
5114
+ // Get the original function's parameters.
5126
5115
SmallVector<AnyFunctionType::Param, 8 > originalParams;
5116
+ // The number of original parameters is equal to the sum of:
5117
+ // - The number of original non-transposed parameters.
5118
+ // - This is the number of transpose parameters minus one. All transpose
5119
+ // parameters come from the original function, except the last parameter
5120
+ // (the transposed original result).
5121
+ // - The number of original transposed parameters.
5122
+ // - This is the number of linearity parameters.
5127
5123
unsigned originalParameterCount =
5128
- transposeParams.size () + wrtParamIndices->getNumIndices () - 1 ;
5124
+ transposeParams.size () - 1 + wrtParamIndices->getNumIndices ();
5125
+ // Iterate over all original parameter indices.
5129
5126
for (auto i : range (originalParameterCount)) {
5130
- // Need to check if it is the 'self' param since we handle it differently
5131
- // above.
5132
- bool lookingAtSelf = (i == wrtParamIndices->getCapacity () - 1 ) && wrtSelf;
5133
- if (wrtParamIndices->contains (i) && !lookingAtSelf) {
5134
- // If in wrt list, the item in the result tuple must be a parameter in the
5135
- // original function.
5127
+ // Skip `self` parameter if `self` is a linearity parameter.
5128
+ // The `self` is handled specially later to form a curried function type.
5129
+ bool isSelfParameterAndWrtSelf =
5130
+ wrtSelf && i == wrtParamIndices->getCapacity () - 1 ;
5131
+ if (isSelfParameterAndWrtSelf)
5132
+ continue ;
5133
+ // If `i` is a linearity parameter index, the next original parameter is
5134
+ // the next transpose result.
5135
+ if (wrtParamIndices->contains (i)) {
5136
5136
auto resultType =
5137
- transposeResultTypes[transposeResultTypesIndex].getType ();
5137
+ transposeResultTypes[transposeResultTypesIndex++ ].getType ();
5138
5138
originalParams.push_back (AnyFunctionType::Param (resultType));
5139
- transposeResultTypesIndex++;
5140
- } else {
5141
- // Else if not in the wrt list, the parameter in the transposing function
5142
- // is a parameter in the original function.
5143
- originalParams.push_back (transposeParams[transposeParamsIndex]);
5144
- transposeParamsIndex++;
5139
+ }
5140
+ // Otherwise, the next original parameter is the next transpose parameter.
5141
+ else {
5142
+ originalParams.push_back (transposeParams[transposeParamsIndex++]);
5145
5143
}
5146
5144
}
5147
5145
5146
+ // Compute the original function type.
5148
5147
AnyFunctionType *originalType;
5148
+ // If the transpose type is curried, the original function type is:
5149
+ // `(Self) -> (<original parameters>) -> <original result>`.
5149
5150
if (isCurried) {
5150
- assert (selfType);
5151
- // If curried, wrap the function into the 'Self' type to get a method.
5151
+ assert (selfType && " `Self` type should be resolved" );
5152
5152
originalType = makeFunctionType (originalParams, originalResult, nullptr );
5153
5153
originalType = makeFunctionType (AnyFunctionType::Param (selfType),
5154
5154
originalType, getOptGenericSignature ());
5155
- } else {
5155
+ }
5156
+ // Otherwise, the original function type is simply:
5157
+ // `(<original parameters>) -> <original result>`.
5158
+ else {
5156
5159
originalType = makeFunctionType (originalParams, originalResult,
5157
5160
getOptGenericSignature ());
5158
5161
}
5159
- assert (originalType);
5160
5162
return originalType;
5161
5163
}
5162
5164
0 commit comments