@@ -1073,6 +1073,8 @@ static bool canSpecializeFullApplySite(FullApplySiteKind kind) {
1073
1073
llvm_unreachable (" covered switch" );
1074
1074
}
1075
1075
1076
+ const int SpecializationLevelLimit = 2 ;
1077
+
1076
1078
static int getSpecializationLevelRecursive (StringRef funcName, Demangler &parent) {
1077
1079
using namespace Demangle ;
1078
1080
@@ -1098,23 +1100,44 @@ static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent
1098
1100
return 0 ;
1099
1101
if (funcSpec->getKind () != Node::Kind::FunctionSignatureSpecialization)
1100
1102
return 0 ;
1101
- Node *param = funcSpec->getChild (1 );
1102
- if (param->getKind () != Node::Kind::FunctionSignatureSpecializationParam)
1103
- return 0 ;
1104
- if (param->getNumChildren () < 2 )
1105
- return 0 ;
1106
- Node *kindNd = param->getChild (0 );
1107
- if (kindNd->getKind () != Node::Kind::FunctionSignatureSpecializationParamKind)
1108
- return 0 ;
1109
- auto kind = FunctionSigSpecializationParamKind (kindNd->getIndex ());
1110
- if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
1111
- return 0 ;
1112
-
1113
- Node *payload = param->getChild (1 );
1114
- if (payload->getKind () != Node::Kind::FunctionSignatureSpecializationParamPayload)
1115
- return 1 ;
1116
- // Check if the specialized function is a specialization itself.
1117
- return 1 + getSpecializationLevelRecursive (payload->getText (), demangler);
1103
+
1104
+ // Match any function specialization. We check for constant propagation at the
1105
+ // parameter level.
1106
+ Node *param = funcSpec->getChild (0 );
1107
+ if (param->getKind () != Node::Kind::SpecializationPassID)
1108
+ return SpecializationLevelLimit + 1 ; // unrecognized format
1109
+
1110
+ unsigned maxParamLevel = 0 ;
1111
+ for (unsigned paramIdx = 1 ; paramIdx < funcSpec->getNumChildren ();
1112
+ ++paramIdx) {
1113
+ Node *param = funcSpec->getChild (paramIdx);
1114
+ if (param->getKind () != Node::Kind::FunctionSignatureSpecializationParam)
1115
+ return SpecializationLevelLimit + 1 ; // unrecognized format
1116
+
1117
+ // A parameter is recursive if it has a kind with index and type payload
1118
+ if (param->getNumChildren () < 2 )
1119
+ continue ;
1120
+
1121
+ Node *kindNd = param->getChild (0 );
1122
+ if (kindNd->getKind ()
1123
+ != Node::Kind::FunctionSignatureSpecializationParamKind) {
1124
+ return SpecializationLevelLimit + 1 ; // unrecognized format
1125
+ }
1126
+ auto kind = FunctionSigSpecializationParamKind (kindNd->getIndex ());
1127
+ if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
1128
+ continue ;
1129
+ Node *payload = param->getChild (1 );
1130
+ if (payload->getKind ()
1131
+ != Node::Kind::FunctionSignatureSpecializationParamPayload) {
1132
+ return SpecializationLevelLimit + 1 ; // unrecognized format
1133
+ }
1134
+ // Check if the specialized function is a specialization itself.
1135
+ unsigned paramLevel =
1136
+ 1 + getSpecializationLevelRecursive (payload->getText (), demangler);
1137
+ if (paramLevel > maxParamLevel)
1138
+ maxParamLevel = paramLevel;
1139
+ }
1140
+ return maxParamLevel;
1118
1141
}
1119
1142
1120
1143
// / If \p function is a function-signature specialization for a constant-
@@ -1328,9 +1351,10 @@ bool SILClosureSpecializerTransform::gatherCallSites(
1328
1351
//
1329
1352
// A limit of 2 is good enough and will not be exceed in "regular"
1330
1353
// optimization scenarios.
1331
- if (getSpecializationLevel (getClosureCallee (ClosureInst)) > 2 )
1354
+ if (getSpecializationLevel (getClosureCallee (ClosureInst))
1355
+ > SpecializationLevelLimit) {
1332
1356
continue ;
1333
-
1357
+ }
1334
1358
// Compute the final release points of the closure. We will insert
1335
1359
// release of the captured arguments here.
1336
1360
if (!CInfo)
@@ -1395,6 +1419,8 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
1395
1419
if (!NewF) {
1396
1420
NewF = ClosureSpecCloner::cloneFunction (FuncBuilder, CSDesc, NewFName);
1397
1421
addFunctionToPassManagerWorklist (NewF, CSDesc.getApplyCallee ());
1422
+ LLVM_DEBUG (llvm::dbgs () << " \n The rewritten callee is:\n " ;
1423
+ NewF->dump ());
1398
1424
}
1399
1425
1400
1426
// Rewrite the call
@@ -1404,6 +1430,10 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
1404
1430
Changed = true ;
1405
1431
}
1406
1432
}
1433
+ LLVM_DEBUG (if (Changed) {
1434
+ llvm::dbgs () << " \n The rewritten caller is:\n " ;
1435
+ Caller->dump ();
1436
+ });
1407
1437
return Changed;
1408
1438
}
1409
1439
0 commit comments