Skip to content

Fix infinite recursion in ClosureSpecialize #61956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 49 additions & 19 deletions lib/SILOptimizer/IPO/ClosureSpecializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,8 @@ static bool canSpecializeFullApplySite(FullApplySiteKind kind) {
llvm_unreachable("covered switch");
}

const int SpecializationLevelLimit = 2;

static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent) {
using namespace Demangle;

Expand All @@ -1098,23 +1100,44 @@ static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent
return 0;
if (funcSpec->getKind() != Node::Kind::FunctionSignatureSpecialization)
return 0;
Node *param = funcSpec->getChild(1);
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
return 0;
if (param->getNumChildren() < 2)
return 0;
Node *kindNd = param->getChild(0);
if (kindNd->getKind() != Node::Kind::FunctionSignatureSpecializationParamKind)
return 0;
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
return 0;

Node *payload = param->getChild(1);
if (payload->getKind() != Node::Kind::FunctionSignatureSpecializationParamPayload)
return 1;
// Check if the specialized function is a specialization itself.
return 1 + getSpecializationLevelRecursive(payload->getText(), demangler);

// Match any function specialization. We check for constant propagation at the
// parameter level.
Node *param = funcSpec->getChild(0);
if (param->getKind() != Node::Kind::SpecializationPassID)
return SpecializationLevelLimit + 1; // unrecognized format

unsigned maxParamLevel = 0;
for (unsigned paramIdx = 1; paramIdx < funcSpec->getNumChildren();
++paramIdx) {
Node *param = funcSpec->getChild(paramIdx);
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
return SpecializationLevelLimit + 1; // unrecognized format

// A parameter is recursive if it has a kind with index and type payload
if (param->getNumChildren() < 2)
continue;

Node *kindNd = param->getChild(0);
if (kindNd->getKind()
!= Node::Kind::FunctionSignatureSpecializationParamKind) {
return SpecializationLevelLimit + 1; // unrecognized format
}
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
continue;
Node *payload = param->getChild(1);
if (payload->getKind()
!= Node::Kind::FunctionSignatureSpecializationParamPayload) {
return SpecializationLevelLimit + 1; // unrecognized format
}
// Check if the specialized function is a specialization itself.
unsigned paramLevel =
1 + getSpecializationLevelRecursive(payload->getText(), demangler);
if (paramLevel > maxParamLevel)
maxParamLevel = paramLevel;
}
return maxParamLevel;
}

/// If \p function is a function-signature specialization for a constant-
Expand Down Expand Up @@ -1328,9 +1351,10 @@ bool SILClosureSpecializerTransform::gatherCallSites(
//
// A limit of 2 is good enough and will not be exceed in "regular"
// optimization scenarios.
if (getSpecializationLevel(getClosureCallee(ClosureInst)) > 2)
if (getSpecializationLevel(getClosureCallee(ClosureInst))
> SpecializationLevelLimit) {
continue;

}
// Compute the final release points of the closure. We will insert
// release of the captured arguments here.
if (!CInfo)
Expand Down Expand Up @@ -1395,6 +1419,8 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
if (!NewF) {
NewF = ClosureSpecCloner::cloneFunction(FuncBuilder, CSDesc, NewFName);
addFunctionToPassManagerWorklist(NewF, CSDesc.getApplyCallee());
LLVM_DEBUG(llvm::dbgs() << "\nThe rewritten callee is:\n";
NewF->dump());
}

// Rewrite the call
Expand All @@ -1404,6 +1430,10 @@ bool SILClosureSpecializerTransform::specialize(SILFunction *Caller,
Changed = true;
}
}
LLVM_DEBUG(if (Changed) {
llvm::dbgs() << "\nThe rewritten caller is:\n";
Caller->dump();
});
return Changed;
}

Expand Down
65 changes: 65 additions & 0 deletions test/SILOptimizer/closure_specialize_loop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,68 @@ public func testit(c: @escaping () -> Bool) {
}
}

// PR: https://github.com/apple/swift/pull/61956
// Optimizing Expression.contains(where:) should not timeout.
//
// Repeated capture propagation leads to:
// func contains$termPred@arg0$[termPred$falsePred@arg1]@arg1(expr) {
// closure = termPred$[termPred$falsePred@arg1]@arg1
// falsePred(expr)
// contains$termPred@arg0$termPred$[termPred$falsePred@arg1]@arg1(expr)
// }
//
// func contains$termPred@arg0$termPred$[termPred$falsePred@arg1]@arg1(expr) {
// closure = [termPred(termPred$[termPred$falsePred@arg1]@arg1)]
// closure(expr)
// contains$termPred@arg0(expr, closure)
// }
// The Demangled type tree looks like:
// kind=FunctionSignatureSpecialization
// kind=SpecializationPassID, index=3
// kind=FunctionSignatureSpecializationParam
// kind=FunctionSignatureSpecializationParam
// kind=FunctionSignatureSpecializationParamKind, index=0
// kind=FunctionSignatureSpecializationParamPayload, text="$s4test10ExpressionO8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_36$s4test12IndirectEnumVACycfcS2bXEfU_Tf3npf_n"
//
// CHECK-LABEL: $s23closure_specialize_loop10ExpressionO8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_012$s23closure_b7_loop10d44O8contains5whereS3bXE_tFSbACXEfU_S2bXEfU_012g13_b7_loop10d44ijk2_tlm2U_no52U_012g30_B34_loop12IndirectEnumVACycfcnO10U_Tf3npf_nY2_nTf3npf_n
// ---> function signature specialization
// <Arg[1] = [Constant Propagated Function : function signature specialization
// <Arg[1] = [Constant Propagated Function : function signature specialization
// <Arg[1] = [Constant Propagated Function : closure #1 (Swift.Bool) -> Swift.Bool
// in closure_specialize_loop.IndirectEnum.init() -> closure_specialize_loop.IndirectEnum]>
// of closure #1 (Swift.Bool) -> Swift.Bool
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool]>
// of closure #1 (Swift.Bool) -> Swift.Bool
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool]>
// of closure #1 (Swift.Bool) -> Swift.Bool
// in closure #1 (closure_specialize_loop.Expression) -> Swift.Bool
// in closure_specialize_loop.Expression.contains(where: (Swift.Bool) -> Swift.Bool) -> Swift.Bool
//
public indirect enum Expression {
case term(Bool)
case list(_ expressions: [Expression])

public func contains(where predicate: (Bool) -> Bool) -> Bool {
switch self {
case let .term(term):
return predicate(term)
case let .list(expressions):
return expressions.contains { expression in
expression.contains { term in
predicate(term)
}
}
}
}
}

public struct IndirectEnum {
public init() {
let containsFalse = Expression.list([.list([.term(true), .term(false)]), .term(true)]).contains { term in
term == false
}
print(containsFalse)
}
}