@@ -1029,49 +1029,56 @@ class DestructureInputs {
1029
1029
1030
1030
} // end anonymous namespace
1031
1031
1032
- static SmallVector<SILParameterInfo, 4 >
1032
+ // / Collects the differentiability parameters of the given original function
1033
+ // / type in `diffParams`.
1034
+ static void
1033
1035
getDifferentiabilityParameters (SILFunctionType *originalFnTy,
1034
- IndexSubset *parameterIndices) {
1036
+ IndexSubset *parameterIndices,
1037
+ SmallVectorImpl<SILParameterInfo> &diffParams) {
1035
1038
// Returns true if `index` is a differentiability parameter index.
1036
1039
auto isDiffParamIndex = [&](unsigned index) -> bool {
1037
1040
return index < parameterIndices->getCapacity () &&
1038
1041
parameterIndices->contains (index);
1039
1042
};
1040
1043
// Calculate differentiability parameter infos.
1041
- SmallVector<SILParameterInfo, 4 > diffParams;
1042
1044
for (auto valueAndIndex : enumerate(originalFnTy->getParameters ()))
1043
1045
if (isDiffParamIndex (valueAndIndex.index ()))
1044
1046
diffParams.push_back (valueAndIndex.value ());
1045
- return diffParams;
1046
1047
}
1047
1048
1048
- static SmallVector<SILResultInfo, 2 >
1049
- getSemanticResults (SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
1049
+ // / Collects the semantic results of the given function type in
1050
+ // / `originalResults`. The semantic results are formal results followed by
1051
+ // / `inout` parameters, in type order.
1052
+ // TODO(TF-983): Generalize to support multiple `inout` parameters. The current
1053
+ // singular `inoutParam` and `isWrtInoutParameter` are hacky.
1054
+ static void
1055
+ getSemanticResults (SILFunctionType *functionType, IndexSubset *parameterIndices,
1050
1056
Optional<SILParameterInfo> &inoutParam,
1051
- bool &isWrtInoutParameter) {
1052
- // Compute the original semantic results: original formal results, followed by
1053
- // `inout` parameters in type order.
1054
- SmallVector<SILResultInfo, 2 > originalResults;
1057
+ bool &isWrtInoutParameter,
1058
+ SmallVectorImpl<SILResultInfo> &originalResults) {
1055
1059
inoutParam = None;
1056
1060
isWrtInoutParameter = false ;
1057
- originalResults.append (originalFnTy->getResults ().begin (),
1058
- originalFnTy->getResults ().end ());
1059
- for (auto i : range (originalFnTy->getNumParameters ())) {
1060
- auto param = originalFnTy->getParameters ()[i];
1061
+ // Collect original formal results.
1062
+ originalResults.append (functionType->getResults ().begin (),
1063
+ functionType->getResults ().end ());
1064
+ // Collect original `inout` parameters.
1065
+ for (auto i : range (functionType->getNumParameters ())) {
1066
+ auto param = functionType->getParameters ()[i];
1061
1067
if (!param.isIndirectInOut ())
1062
1068
continue ;
1063
1069
inoutParam = param;
1064
1070
isWrtInoutParameter = parameterIndices->contains (i);
1065
1071
originalResults.push_back (
1066
1072
SILResultInfo (param.getInterfaceType (), ResultConvention::Indirect));
1067
1073
}
1068
- return originalResults;
1069
1074
}
1070
1075
1076
+ // / Returns the differential type for the given original function type,
1077
+ // / parameter indices, and result index.
1071
1078
static CanSILFunctionType
1072
- getDifferentialType (SILFunctionType *originalFnTy,
1073
- IndexSubset *parameterIndices, unsigned resultIndex,
1074
- LookupConformanceFn lookupConformance) {
1079
+ getAutoDiffDifferentialType (SILFunctionType *originalFnTy,
1080
+ IndexSubset *parameterIndices, unsigned resultIndex,
1081
+ LookupConformanceFn lookupConformance) {
1075
1082
auto &ctx = originalFnTy->getASTContext ();
1076
1083
SmallVector<GenericTypeParamType *, 4 > substGenericParams;
1077
1084
SmallVector<Requirement, 4 > substRequirements;
@@ -1080,12 +1087,14 @@ getDifferentialType(SILFunctionType *originalFnTy,
1080
1087
1081
1088
Optional<SILParameterInfo> inoutParam = None;
1082
1089
bool isWrtInoutParameter = false ;
1083
- SmallVector<SILResultInfo, 2 > originalResults = getSemanticResults (
1084
- originalFnTy, parameterIndices, inoutParam, isWrtInoutParameter);
1090
+ SmallVector<SILResultInfo, 2 > originalResults;
1091
+ getSemanticResults (originalFnTy, parameterIndices, inoutParam,
1092
+ isWrtInoutParameter, originalResults);
1085
1093
1094
+ SmallVector<SILParameterInfo, 4 > diffParams;
1095
+ getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
1086
1096
SmallVector<SILParameterInfo, 8 > differentialParams;
1087
- for (auto ¶m :
1088
- getDifferentiabilityParameters (originalFnTy, parameterIndices)) {
1097
+ for (auto ¶m : diffParams) {
1089
1098
auto paramTan =
1090
1099
param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
1091
1100
assert (paramTan && " Parameter type does not have a tangent space?" );
@@ -1136,11 +1145,13 @@ getDifferentialType(SILFunctionType *originalFnTy,
1136
1145
differentialResults, None, substitutions, impliedSignature, ctx);
1137
1146
}
1138
1147
1139
- static CanSILFunctionType getPullbackType (SILFunctionType *originalFnTy,
1140
- IndexSubset *parameterIndices,
1141
- unsigned resultIndex,
1142
- LookupConformanceFn lookupConformance,
1143
- TypeConverter &TC) {
1148
+ // / Returns the pullback type for the given original function type, parameter
1149
+ // / indices, and result index.
1150
+ static CanSILFunctionType
1151
+ getAutoDiffPullbackType (SILFunctionType *originalFnTy,
1152
+ IndexSubset *parameterIndices, unsigned resultIndex,
1153
+ LookupConformanceFn lookupConformance,
1154
+ TypeConverter &TC) {
1144
1155
auto &ctx = originalFnTy->getASTContext ();
1145
1156
SmallVector<GenericTypeParamType *, 4 > substGenericParams;
1146
1157
SmallVector<Requirement, 4 > substRequirements;
@@ -1149,8 +1160,9 @@ static CanSILFunctionType getPullbackType(SILFunctionType *originalFnTy,
1149
1160
1150
1161
Optional<SILParameterInfo> inoutParam = None;
1151
1162
bool isWrtInoutParameter = false ;
1152
- SmallVector<SILResultInfo, 2 > originalResults = getSemanticResults (
1153
- originalFnTy, parameterIndices, inoutParam, isWrtInoutParameter);
1163
+ SmallVector<SILResultInfo, 2 > originalResults;
1164
+ getSemanticResults (originalFnTy, parameterIndices, inoutParam,
1165
+ isWrtInoutParameter, originalResults);
1154
1166
1155
1167
// Given a type, returns its formal SIL parameter info.
1156
1168
auto getTangentParameterConventionForOriginalResult =
@@ -1251,9 +1263,10 @@ static CanSILFunctionType getPullbackType(SILFunctionType *originalFnTy,
1251
1263
pullbackParams.push_back ({gpType, paramTanConvention});
1252
1264
}
1253
1265
}
1266
+ SmallVector<SILParameterInfo, 4 > diffParams;
1267
+ getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
1254
1268
SmallVector<SILResultInfo, 8 > pullbackResults;
1255
- for (auto ¶m :
1256
- getDifferentiabilityParameters (originalFnTy, parameterIndices)) {
1269
+ for (auto ¶m : diffParams) {
1257
1270
if (param.isIndirectInOut ())
1258
1271
continue ;
1259
1272
auto paramTan =
@@ -1378,12 +1391,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
1378
1391
CanSILFunctionType closureType;
1379
1392
switch (kind) {
1380
1393
case AutoDiffDerivativeFunctionKind::JVP:
1381
- closureType = getDifferentialType (constrainedOriginalFnTy, parameterIndices,
1382
- resultIndex, lookupConformance);
1394
+ closureType =
1395
+ getAutoDiffDifferentialType (constrainedOriginalFnTy, parameterIndices,
1396
+ resultIndex, lookupConformance);
1383
1397
break ;
1384
1398
case AutoDiffDerivativeFunctionKind::VJP:
1385
- closureType = getPullbackType (constrainedOriginalFnTy, parameterIndices,
1386
- resultIndex, lookupConformance, TC);
1399
+ closureType =
1400
+ getAutoDiffPullbackType (constrainedOriginalFnTy, parameterIndices,
1401
+ resultIndex, lookupConformance, TC);
1387
1402
break ;
1388
1403
}
1389
1404
0 commit comments