Skip to content

Commit 0592979

Browse files
authored
Merge pull request #61956 from atrick/fix-closurespecialize-hang
Fix infinite recursion in ClosureSpecialize
2 parents c69184e + c665311 commit 0592979

File tree

2 files changed

+114
-19
lines changed

2 files changed

+114
-19
lines changed

lib/SILOptimizer/IPO/ClosureSpecializer.cpp

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,8 @@ static bool canSpecializeFullApplySite(FullApplySiteKind kind) {
10731073
llvm_unreachable("covered switch");
10741074
}
10751075

1076+
const int SpecializationLevelLimit = 2;
1077+
10761078
static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent) {
10771079
using namespace Demangle;
10781080

@@ -1098,23 +1100,44 @@ static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent
10981100
return 0;
10991101
if (funcSpec->getKind() != Node::Kind::FunctionSignatureSpecialization)
11001102
return 0;
1101-
Node *param = funcSpec->getChild(1);
1102-
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
1103-
return 0;
1104-
if (param->getNumChildren() < 2)
1105-
return 0;
1106-
Node *kindNd = param->getChild(0);
1107-
if (kindNd->getKind() != Node::Kind::FunctionSignatureSpecializationParamKind)
1108-
return 0;
1109-
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
1110-
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
1111-
return 0;
1112-
1113-
Node *payload = param->getChild(1);
1114-
if (payload->getKind() != Node::Kind::FunctionSignatureSpecializationParamPayload)
1115-
return 1;
1116-
// Check if the specialized function is a specialization itself.
1117-
return 1 + getSpecializationLevelRecursive(payload->getText(), demangler);
1103+
1104+
// Match any function specialization. We check for constant propagation at the
1105+
// parameter level.
1106+
Node *param = funcSpec->getChild(0);
1107+
if (param->getKind() != Node::Kind::SpecializationPassID)
1108+
return SpecializationLevelLimit + 1; // unrecognized format
1109+
1110+
unsigned maxParamLevel = 0;
1111+
for (unsigned paramIdx = 1; paramIdx < funcSpec->getNumChildren();
1112+
++paramIdx) {
1113+
Node *param = funcSpec->getChild(paramIdx);
1114+
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
1115+
return SpecializationLevelLimit + 1; // unrecognized format
1116+
1117+
// A parameter is recursive if it has a kind with index and type payload
1118+
if (param->getNumChildren() < 2)
1119+
continue;
1120+
1121+
Node *kindNd = param->getChild(0);
1122+
if (kindNd->getKind()
1123+
!= Node::Kind::FunctionSignatureSpecializationParamKind) {
1124+
return SpecializationLevelLimit + 1; // unrecognized format
1125+
}
1126+
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
1127+
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
1128+
continue;
1129+
Node *payload = param->getChild(1);
1130+
if (payload->getKind()
1131+
!= Node::Kind::FunctionSignatureSpecializationParamPayload) {
1132+
return SpecializationLevelLimit + 1; // unrecognized format
1133+
}
1134+
// Check if the specialized function is a specialization itself.
1135+
unsigned paramLevel =
1136+
1 + getSpecializationLevelRecursive(payload->getText(), demangler);
1137+
if (paramLevel > maxParamLevel)
1138+
maxParamLevel = paramLevel;
1139+
}
1140+
return maxParamLevel;
11181141
}
11191142

11201143
/// If \p function is a function-signature specialization for a constant-
@@ -1328,9 +1351,10 @@ bool SILClosureSpecializerTransform::gatherCallSites(
13281351
//
13291352
// A limit of 2 is good enough and will not be exceed in "regular"
13301353
// optimization scenarios.
1331-
if (getSpecializationLevel(getClosureCallee(ClosureInst)) > 2)
1354+
if (getSpecializationLevel(getClosureCallee(ClosureInst))
1355+
> SpecializationLevelLimit) {
13321356
continue;
1333-
1357+
}
13341358
// Compute the final release points of the closure. We will insert
13351359
// release of the captured arguments here.
13361360
if (!CInfo)
@@ -1395,6 +1419,8 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
13951419
if (!NewF) {
13961420
NewF = ClosureSpecCloner::cloneFunction(FuncBuilder, CSDesc, NewFName);
13971421
addFunctionToPassManagerWorklist(NewF, CSDesc.getApplyCallee());
1422+
LLVM_DEBUG(llvm::dbgs() << "\nThe rewritten callee is:\n";
1423+
NewF->dump());
13981424
}
13991425

14001426
// Rewrite the call
@@ -1404,6 +1430,10 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
14041430
Changed = true;
14051431
}
14061432
}
1433+
LLVM_DEBUG(if (Changed) {
1434+
llvm::dbgs() << "\nThe rewritten caller is:\n";
1435+
Caller->dump();
1436+
});
14071437
return Changed;
14081438
}
14091439

test/SILOptimizer/closure_specialize_loop.swift

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,68 @@ public func testit(c: @escaping () -> Bool) {
1414
}
1515
}
1616

17+
// PR: https://github.com/apple/swift/pull/61956
18+
// Optimizing Expression.contains(where:) should not timeout.
19+
//
20+
// Repeated capture propagation leads to:
21+
// func contains$termPred@arg0$[termPred$falsePred@arg1]@arg1(expr) {
22+
// closure = termPred$[termPred$falsePred@arg1]@arg1
23+
// falsePred(expr)
24+
// contains$termPred@arg0$termPred$[termPred$falsePred@arg1]@arg1(expr)
25+
// }
26+
//
27+
// func contains$termPred@arg0$termPred$[termPred$falsePred@arg1]@arg1(expr) {
28+
// closure = [termPred(termPred$[termPred$falsePred@arg1]@arg1)]
29+
// closure(expr)
30+
// contains$termPred@arg0(expr, closure)
31+
// }
32+
// The Demangled type tree looks like:
33+
// kind=FunctionSignatureSpecialization
34+
// kind=SpecializationPassID, index=3
35+
// kind=FunctionSignatureSpecializationParam
36+
// kind=FunctionSignatureSpecializationParam
37+
// kind=FunctionSignatureSpecializationParamKind, index=0
38+
// kind=FunctionSignatureSpecializationParamPayload, text="$s4test10ExpressionO8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_36$s4test12IndirectEnumVACycfcS2bXEfU_Tf3npf_n"
39+
//
40+
// CHECK-LABEL: $s23closure_specialize_loop10ExpressionO8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_012$s23closure_b7_loop10d44O8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_012g13_b7_loop10d44ijk2_tlm2U_no52U_012g30_B34_loop12IndirectEnumVACycfcnO10U_Tf3npf_nY2_nTf3npf_n
41+
// ---> function signature specialization
42+
// <Arg[1] = [Constant Propagated Function : function signature specialization
43+
// <Arg[1] = [Constant Propagated Function : function signature specialization
44+
// <Arg[1] = [Constant Propagated Function : closure #1 (Swift.Bool) -> Swift.Bool
45+
// in closure_specialize_loop.IndirectEnum.init() -> closure_specialize_loop.IndirectEnum]>
46+
// of closure #1 (Swift.Bool) -> Swift.Bool
47+
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
48+
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool]>
49+
// of closure #1 (Swift.Bool) -> Swift.Bool
50+
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
51+
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool]>
52+
// of closure #1 (Swift.Bool) -> Swift.Bool
53+
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
54+
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool
55+
//
56+
public indirect enum Expression {
57+
case term(Bool)
58+
case list(_ expressions: [Expression])
59+
60+
public func contains(where predicate: (Bool) -> Bool) -> Bool {
61+
switch self {
62+
case let .term(term):
63+
return predicate(term)
64+
case let .list(expressions):
65+
return expressions.contains { expression in
66+
expression.contains { term in
67+
predicate(term)
68+
}
69+
}
70+
}
71+
}
72+
}
73+
74+
public struct IndirectEnum {
75+
public init() {
76+
let containsFalse = Expression.list([.list([.term(true), .term(false)]), .term(true)]).contains { term in
77+
term == false
78+
}
79+
print(containsFalse)
80+
}
81+
}

0 commit comments

Comments
 (0)