@@ -4742,25 +4742,24 @@ makeFunctionType(ArrayRef<AnyFunctionType::Param> params, Type retTy,
4742
4742
// Compute the original function type corresponding to the given transpose
4743
4743
// function type.
4744
4744
AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType (
4745
- TransposingAttr *attr, IndexSubset *wrtParamIndices, bool wrtSelf) {
4745
+ IndexSubset *wrtParamIndices, bool wrtSelf) {
4746
4746
unsigned transposeParamsIndex = 0 ;
4747
4747
bool isCurried = getResult ()->is <AnyFunctionType>();
4748
-
4748
+
4749
4749
// Get the original function's result.
4750
4750
auto transposeParams = getParams ();
4751
4751
auto transposeResult = getResult ();
4752
4752
if (isCurried) {
4753
- auto method =
4754
- getAs<AnyFunctionType>()->getResult ()->getAs <AnyFunctionType>();
4755
- transposeParams = method->getParams ();
4756
- transposeResult = method->getResult ();
4753
+ auto methodType = getResult ()->castTo <AnyFunctionType>();
4754
+ transposeParams = methodType->getParams ();
4755
+ transposeResult = methodType->getResult ();
4757
4756
}
4758
-
4757
+
4759
4758
Type originalResult;
4760
4759
if (isCurried) {
4761
4760
// If it's curried, then the first parameter in the curried type, which is
4762
4761
// the 'Self' type, is the original result (no matter if we are
4763
- // differentiating WRT self or aren't ).
4762
+ // transposing wrt self or not ).
4764
4763
originalResult = getParams ().front ().getPlainType ();
4765
4764
} else {
4766
4765
// If it's not curried, the last parameter, the tangent, is always the
@@ -4770,22 +4769,21 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
4770
4769
}
4771
4770
assert (originalResult);
4772
4771
4773
- auto wrtParams = attr->getParsedParameters ();
4774
4772
SmallVector<TupleTypeElt, 4 > transposeResultTypes;
4775
4773
// Return type of '@transposing' function can have single type or tuples
4776
4774
// of types.
4777
- if (auto t = transposeResult->getAs <TupleType>()) {
4778
- transposeResultTypes.append (t ->getElements ().begin (),
4779
- t ->getElements ().end ());
4775
+ if (auto transposeResultTupleType = transposeResult->getAs <TupleType>()) {
4776
+ transposeResultTypes.append (transposeResultTupleType ->getElements ().begin (),
4777
+ transposeResultTupleType ->getElements ().end ());
4780
4778
} else {
4781
4779
transposeResultTypes.push_back (transposeResult);
4782
4780
}
4783
4781
assert (!transposeResultTypes.empty ());
4784
4782
4785
- // If the function is curried and is transposing WRT 'self', then grab
4783
+ // If the function is curried and is transposing wrt 'self', then grab
4786
4784
// the type from the result list (guaranteed to be the first since 'self'
4787
- // is first in WRT list) and remove it. If it's still curried but not
4788
- // transposing WRT 'self', then the 'Self' type is the first parameter
4785
+ // is first in wrt list) and remove it. If it is still curried but not
4786
+ // transposing wrt 'self', then the 'Self' type is the first parameter
4789
4787
// in the method.
4790
4788
unsigned transposeResultTypesIndex = 0 ;
4791
4789
Type selfType;
@@ -4798,21 +4796,21 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
4798
4796
}
4799
4797
4800
4798
SmallVector<AnyFunctionType::Param, 8 > originalParams;
4801
- unsigned numberOriginalParameters =
4802
- transposeParams.size () + wrtParams. size () - 1 ;
4803
- for (auto i : range (numberOriginalParameters )) {
4799
+ unsigned originalParameterCount =
4800
+ transposeParams.size () + wrtParamIndices-> getNumIndices () - 1 ;
4801
+ for (auto i : range (originalParameterCount )) {
4804
4802
// Need to check if it is the 'self' param since we handle it differently
4805
4803
// above.
4806
- bool lookingAtSelf = (i == (wrtParamIndices->getCapacity () - 1 )) && wrtSelf;
4807
- bool isWrt = wrtParamIndices->contains (i);
4808
- if (isWrt && !lookingAtSelf) {
4809
- // If in WRT list, the item in the result tuple must be a parameter in the
4804
+ bool lookingAtSelf = (i == wrtParamIndices->getCapacity () - 1 ) && wrtSelf;
4805
+ if (wrtParamIndices->contains (i) && !lookingAtSelf) {
4806
+ // If in wrt list, the item in the result tuple must be a parameter in the
4810
4807
// original function.
4811
- auto resultType = transposeResultTypes[transposeResultTypesIndex].getType ();
4808
+ auto resultType =
4809
+ transposeResultTypes[transposeResultTypesIndex].getType ();
4812
4810
originalParams.push_back (AnyFunctionType::Param (resultType));
4813
4811
transposeResultTypesIndex++;
4814
4812
} else {
4815
- // Else if not in the WRT list, the parameter in the transposing function
4813
+ // Else if not in the wrt list, the parameter in the transposing function
4816
4814
// is a parameter in the original function.
4817
4815
originalParams.push_back (transposeParams[transposeParamsIndex]);
4818
4816
transposeParamsIndex++;
0 commit comments