Skip to content

Commit c665311

Browse files
committed
Fix infinite recursion in ClosureSpecialize
Fixes getSpecializationLevelRecursive to handle recursive manglings caused by interleaving CapturePropagation and ClosureSpecialize passes. For some reason, only the first closure parameter was checked for recursion. We need to handle patterns like this: kind=FunctionSignatureSpecialization kind=SpecializationPassID, index=3 kind=FunctionSignatureSpecializationParam kind=FunctionSignatureSpecializationParam kind=FunctionSignatureSpecializationParamKind, index=0 kind=FunctionSignatureSpecializationParamPayload, text="$s4test10ExpressionO8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_36$s4test12IndirectEnumVACycfcS2bXEfU_Tf3npf_n" I fixed the logic so we now check for recursion on all closure parameters and bail out on unrecognized mangling formats. For reference, see summary.sil in Infinitely recursive closure specialization #61955 #61955 Fixes rdar://101589190 (Swift Compiler hangs when building this code for release)
1 parent 9c321c0 commit c665311

File tree

2 files changed

+114
-19
lines changed

2 files changed

+114
-19
lines changed

lib/SILOptimizer/IPO/ClosureSpecializer.cpp

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,8 @@ static bool canSpecializeFullApplySite(FullApplySiteKind kind) {
10731073
llvm_unreachable("covered switch");
10741074
}
10751075

1076+
const int SpecializationLevelLimit = 2;
1077+
10761078
static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent) {
10771079
using namespace Demangle;
10781080

@@ -1098,23 +1100,44 @@ static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent
10981100
return 0;
10991101
if (funcSpec->getKind() != Node::Kind::FunctionSignatureSpecialization)
11001102
return 0;
1101-
Node *param = funcSpec->getChild(1);
1102-
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
1103-
return 0;
1104-
if (param->getNumChildren() < 2)
1105-
return 0;
1106-
Node *kindNd = param->getChild(0);
1107-
if (kindNd->getKind() != Node::Kind::FunctionSignatureSpecializationParamKind)
1108-
return 0;
1109-
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
1110-
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
1111-
return 0;
1112-
1113-
Node *payload = param->getChild(1);
1114-
if (payload->getKind() != Node::Kind::FunctionSignatureSpecializationParamPayload)
1115-
return 1;
1116-
// Check if the specialized function is a specialization itself.
1117-
return 1 + getSpecializationLevelRecursive(payload->getText(), demangler);
1103+
1104+
// Match any function specialization. We check for constant propagation at the
1105+
// parameter level.
1106+
Node *param = funcSpec->getChild(0);
1107+
if (param->getKind() != Node::Kind::SpecializationPassID)
1108+
return SpecializationLevelLimit + 1; // unrecognized format
1109+
1110+
unsigned maxParamLevel = 0;
1111+
for (unsigned paramIdx = 1; paramIdx < funcSpec->getNumChildren();
1112+
++paramIdx) {
1113+
Node *param = funcSpec->getChild(paramIdx);
1114+
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
1115+
return SpecializationLevelLimit + 1; // unrecognized format
1116+
1117+
// A parameter is recursive if it has a kind with index and type payload
1118+
if (param->getNumChildren() < 2)
1119+
continue;
1120+
1121+
Node *kindNd = param->getChild(0);
1122+
if (kindNd->getKind()
1123+
!= Node::Kind::FunctionSignatureSpecializationParamKind) {
1124+
return SpecializationLevelLimit + 1; // unrecognized format
1125+
}
1126+
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
1127+
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
1128+
continue;
1129+
Node *payload = param->getChild(1);
1130+
if (payload->getKind()
1131+
!= Node::Kind::FunctionSignatureSpecializationParamPayload) {
1132+
return SpecializationLevelLimit + 1; // unrecognized format
1133+
}
1134+
// Check if the specialized function is a specialization itself.
1135+
unsigned paramLevel =
1136+
1 + getSpecializationLevelRecursive(payload->getText(), demangler);
1137+
if (paramLevel > maxParamLevel)
1138+
maxParamLevel = paramLevel;
1139+
}
1140+
return maxParamLevel;
11181141
}
11191142

11201143
/// If \p function is a function-signature specialization for a constant-
@@ -1328,9 +1351,10 @@ bool SILClosureSpecializerTransform::gatherCallSites(
13281351
//
13291352
// A limit of 2 is good enough and will not be exceed in "regular"
13301353
// optimization scenarios.
1331-
if (getSpecializationLevel(getClosureCallee(ClosureInst)) > 2)
1354+
if (getSpecializationLevel(getClosureCallee(ClosureInst))
1355+
> SpecializationLevelLimit) {
13321356
continue;
1333-
1357+
}
13341358
// Compute the final release points of the closure. We will insert
13351359
// release of the captured arguments here.
13361360
if (!CInfo)
@@ -1395,6 +1419,8 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
13951419
if (!NewF) {
13961420
NewF = ClosureSpecCloner::cloneFunction(FuncBuilder, CSDesc, NewFName);
13971421
addFunctionToPassManagerWorklist(NewF, CSDesc.getApplyCallee());
1422+
LLVM_DEBUG(llvm::dbgs() << "\nThe rewritten callee is:\n";
1423+
NewF->dump());
13981424
}
13991425

14001426
// Rewrite the call
@@ -1404,6 +1430,10 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
14041430
Changed = true;
14051431
}
14061432
}
1433+
LLVM_DEBUG(if (Changed) {
1434+
llvm::dbgs() << "\nThe rewritten caller is:\n";
1435+
Caller->dump();
1436+
});
14071437
return Changed;
14081438
}
14091439

test/SILOptimizer/closure_specialize_loop.swift

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,68 @@ public func testit(c: @escaping () -> Bool) {
1414
}
1515
}
1616

17+
// PR: https://github.com/apple/swift/pull/61956
18+
// Optimizing Expression.contains(where:) should not timeout.
19+
//
20+
// Repeated capture propagation leads to:
21+
// func contains$termPred@arg0$[termPred$falsePred@arg1]@arg1(expr) {
22+
// closure = termPred$[termPred$falsePred@arg1]@arg1
23+
// falsePred(expr)
24+
// contains$termPred@arg0$termPred$[termPred$falsePred@arg1]@arg1(expr)
25+
// }
26+
//
27+
// func contains$termPred@arg0$termPred$[termPred$falsePred@arg1]@arg1(expr) {
28+
// closure = [termPred(termPred$[termPred$falsePred@arg1]@arg1)]
29+
// closure(expr)
30+
// contains$termPred@arg0(expr, closure)
31+
// }
32+
// The Demangled type tree looks like:
33+
// kind=FunctionSignatureSpecialization
34+
// kind=SpecializationPassID, index=3
35+
// kind=FunctionSignatureSpecializationParam
36+
// kind=FunctionSignatureSpecializationParam
37+
// kind=FunctionSignatureSpecializationParamKind, index=0
38+
// kind=FunctionSignatureSpecializationParamPayload, text="$s4test10ExpressionO8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_36$s4test12IndirectEnumVACycfcS2bXEfU_Tf3npf_n"
39+
//
40+
// CHECK-LABEL: $s23closure_specialize_loop10ExpressionO8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_012$s23closure_b7_loop10d44O8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_012g13_b7_loop10d44ijk2_tlm2U_no52U_012g30_B34_loop12IndirectEnumVACycfcnO10U_Tf3npf_nY2_nTf3npf_n
41+
// ---> function signature specialization
42+
// <Arg[1] = [Constant Propagated Function : function signature specialization
43+
// <Arg[1] = [Constant Propagated Function : function signature specialization
44+
// <Arg[1] = [Constant Propagated Function : closure #1 (Swift.Bool) -> Swift.Bool
45+
// in closure_specialize_loop.IndirectEnum.init() -> closure_specialize_loop.IndirectEnum]>
46+
// of closure #1 (Swift.Bool) -> Swift.Bool
47+
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
48+
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool]>
49+
// of closure #1 (Swift.Bool) -> Swift.Bool
50+
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
51+
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool]>
52+
// of closure #1 (Swift.Bool) -> Swift.Bool
53+
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
54+
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool
55+
//
56+
public indirect enum Expression {
57+
case term(Bool)
58+
case list(_ expressions: [Expression])
59+
60+
public func contains(where predicate: (Bool) -> Bool) -> Bool {
61+
switch self {
62+
case let .term(term):
63+
return predicate(term)
64+
case let .list(expressions):
65+
return expressions.contains { expression in
66+
expression.contains { term in
67+
predicate(term)
68+
}
69+
}
70+
}
71+
}
72+
}
73+
74+
public struct IndirectEnum {
75+
public init() {
76+
let containsFalse = Expression.list([.list([.term(true), .term(false)]), .term(true)]).contains { term in
77+
term == false
78+
}
79+
print(containsFalse)
80+
}
81+
}

0 commit comments

Comments
 (0)