Skip to content

Commit 2905ee1

Browse files
committed
[capture-propagation] Support generic partial_apply instructions
1 parent 65091d6 commit 2905ee1

File tree

2 files changed

+162
-30
lines changed

2 files changed

+162
-30
lines changed

lib/SILOptimizer/IPO/CapturePropagation.cpp

Lines changed: 90 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#define DEBUG_TYPE "capture-prop"
14+
#include "swift/AST/GenericEnvironment.h"
1415
#include "swift/SILOptimizer/PassManager/Passes.h"
16+
#include "swift/SILOptimizer/Utils/Generics.h"
1517
#include "swift/SILOptimizer/Utils/SpecializationMangler.h"
1618
#include "swift/Demangling/Demangle.h"
1719
#include "swift/SIL/Mangle.h"
1820
#include "swift/SIL/SILCloner.h"
1921
#include "swift/SIL/SILInstruction.h"
22+
#include "swift/SIL/TypeSubstCloner.h"
2023
#include "swift/SILOptimizer/Analysis/ColdBlockInfo.h"
2124
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
2225
#include "swift/SILOptimizer/PassManager/Transforms.h"
@@ -99,16 +102,17 @@ namespace {
99102
/// caller, so the cloned function will have a mix of locations from different
100103
/// functions.
101104
class CapturePropagationCloner
102-
: public SILClonerWithScopes<CapturePropagationCloner> {
103-
using SuperTy = SILClonerWithScopes<CapturePropagationCloner>;
105+
: public TypeSubstCloner<CapturePropagationCloner> {
106+
using SuperTy = TypeSubstCloner<CapturePropagationCloner>;
104107
friend class SILVisitor<CapturePropagationCloner>;
105108
friend class SILCloner<CapturePropagationCloner>;
106109

107110
SILFunction *OrigF;
108111
bool IsCloningConstant;
109112
public:
110-
CapturePropagationCloner(SILFunction *OrigF, SILFunction *NewF)
111-
: SuperTy(*NewF), OrigF(OrigF), IsCloningConstant(false) {}
113+
CapturePropagationCloner(SILFunction *OrigF, SILFunction *NewF,
114+
SubstitutionList Subs)
115+
: SuperTy(*NewF, *OrigF, Subs), OrigF(OrigF), IsCloningConstant(false) {}
112116

113117
void cloneBlocks(OperandValueArrayRef Args);
114118

@@ -219,6 +223,20 @@ void CapturePropagationCloner::cloneBlocks(
219223
}
220224
}
221225

226+
CanSILFunctionType getPartialApplyInterfaceResultType(PartialApplyInst *PAI) {
227+
SILFunction *OrigF = PAI->getReferencedFunction();
228+
// The new partial_apply will no longer take any arguments--they are all
229+
// expressed as literals. So its callee signature will be the same as its
230+
// return signature.
231+
auto FTy = PAI->getType().castTo<SILFunctionType>();
232+
CanGenericSignature CanGenericSig;
233+
assert(!PAI->hasSubstitutions() || !hasArchetypes(PAI->getSubstitutions()));
234+
FTy = cast<SILFunctionType>(
235+
OrigF->mapTypeOutOfContext(FTy)->getCanonicalType());
236+
auto NewFTy = FTy;
237+
return NewFTy;
238+
}
239+
222240
/// Given a partial_apply instruction, create a specialized callee by removing
223241
/// all constant arguments and adding constant literals to the specialized
224242
/// function body.
@@ -243,12 +261,16 @@ SILFunction *CapturePropagation::specializeConstClosure(PartialApplyInst *PAI,
243261
// The new partial_apply will no longer take any arguments--they are all
244262
// expressed as literals. So its callee signature will be the same as its
245263
// return signature.
246-
CanSILFunctionType NewFTy =
247-
Lowering::adjustFunctionType(PAI->getType().castTo<SILFunctionType>(),
248-
SILFunctionType::Representation::Thin);
264+
auto NewFTy = getPartialApplyInterfaceResultType(PAI);
265+
NewFTy = Lowering::adjustFunctionType(NewFTy,
266+
SILFunctionType::Representation::Thin);
267+
268+
GenericEnvironment *GenericEnv = nullptr;
269+
if (NewFTy->getGenericSignature())
270+
GenericEnv = OrigF->getGenericEnvironment();
249271
SILFunction *NewF = OrigF->getModule().createFunction(
250272
SILLinkage::Shared, Name, NewFTy,
251-
OrigF->getGenericEnvironment(), OrigF->getLocation(), OrigF->isBare(),
273+
GenericEnv, OrigF->getLocation(), OrigF->isBare(),
252274
OrigF->isTransparent(), Fragile, OrigF->isThunk(),
253275
OrigF->getClassVisibility(), OrigF->getInlineStrategy(),
254276
OrigF->getEffectsKind(),
@@ -259,18 +281,28 @@ SILFunction *CapturePropagation::specializeConstClosure(PartialApplyInst *PAI,
259281
DEBUG(llvm::dbgs() << " Specialize callee as ";
260282
NewF->printName(llvm::dbgs()); llvm::dbgs() << " " << NewFTy << "\n");
261283

262-
CapturePropagationCloner cloner(OrigF, NewF);
284+
DEBUG(if (PAI->hasSubstitutions()) {
285+
llvm::dbgs() << "CapturePropagation of generic partial_apply:\n";
286+
PAI->dumpInContext();
287+
});
288+
CapturePropagationCloner cloner(OrigF, NewF, PAI->getSubstitutions());
263289
cloner.cloneBlocks(PAI->getArguments());
264290
assert(OrigF->getDebugScope()->Parent != NewF->getDebugScope()->Parent);
265291
return NewF;
266292
}
267293

268294
void CapturePropagation::rewritePartialApply(PartialApplyInst *OrigPAI,
269295
SILFunction *SpecialF) {
296+
DEBUG(llvm::dbgs() << "\n Rewriting a partial apply:\n";
297+
OrigPAI->dumpInContext(); llvm::dbgs() << " with special function: "
298+
<< SpecialF->getName() << "\n";
299+
llvm::dbgs() << "\nThe function being rewritten is:\n";
300+
OrigPAI->getFunction()->dump());
301+
270302
SILBuilderWithScope Builder(OrigPAI);
271303
auto FuncRef = Builder.createFunctionRef(OrigPAI->getLoc(), SpecialF);
272-
auto *T2TF = Builder.createThinToThickFunction(OrigPAI->getLoc(),
273-
FuncRef, OrigPAI->getType());
304+
auto *T2TF = Builder.createThinToThickFunction(OrigPAI->getLoc(), FuncRef,
305+
OrigPAI->getType());
274306
OrigPAI->replaceAllUsesWith(T2TF);
275307
recursivelyDeleteTriviallyDeadInstructions(OrigPAI, true);
276308
DEBUG(llvm::dbgs() << " Rewrote caller:\n" << *T2TF);
@@ -311,12 +343,13 @@ static bool onlyContainsReturnOrThrowOfArg(SILBasicBlock *BB) {
311343

312344
/// Checks if \p Orig is a thunk which calls another function but without
313345
/// passing the trailing \p numDeadParams dead parameters.
314-
static SILFunction *getSpecializedWithDeadParams(SILFunction *Orig,
315-
int numDeadParams) {
346+
static SILFunction *getSpecializedWithDeadParams(
347+
PartialApplyInst *PAI, SILFunction *Orig, int numDeadParams,
348+
std::pair<SILFunction *, SILFunction *> &GenericSpecialized) {
316349
SILBasicBlock &EntryBB = *Orig->begin();
317350
unsigned NumArgs = EntryBB.getNumArguments();
318351
SILModule &M = Orig->getModule();
319-
352+
320353
// Check if all dead parameters have trivial types. We don't support non-
321354
// trivial types because it's very hard to find places where we can release
322355
// those parameters (as a replacement for the removed partial_apply).
@@ -328,20 +361,20 @@ static SILFunction *getSpecializedWithDeadParams(SILFunction *Orig,
328361
}
329362
SILFunction *Specialized = nullptr;
330363
SILValue RetValue;
331-
364+
332365
// Check all instruction of the entry block.
333366
for (SILInstruction &I : EntryBB) {
334367
if (auto FAS = FullApplySite::isa(&I)) {
335-
336368
// Check if this is the call of the specialized function.
337-
// As the original function is not generic, also the specialized function
338-
// must be not generic.
339-
if (FAS.hasSubstitutions())
369+
// If the original partial_apply didn't have substitutions,
370+
// also the specialized function must be not generic.
371+
if (!PAI->hasSubstitutions() && FAS.hasSubstitutions())
340372
return nullptr;
373+
341374
// Is it the only call?
342375
if (Specialized)
343376
return nullptr;
344-
377+
345378
Specialized = FAS.getReferencedFunction();
346379
if (!Specialized)
347380
return nullptr;
@@ -376,29 +409,55 @@ static SILFunction *getSpecializedWithDeadParams(SILFunction *Orig,
376409
if (I.mayHaveSideEffects() || isa<TermInst>(&I))
377410
return nullptr;
378411
}
412+
413+
GenericSpecialized = std::make_pair(nullptr, nullptr);
414+
415+
if (PAI->hasSubstitutions()) {
416+
if (Specialized->isExternalDeclaration())
417+
return nullptr;
418+
// Perform a generic specialization of the Specialized function.
419+
ReabstractionInfo ReInfo(ApplySite(), Specialized, PAI->getSubstitutions(),
420+
/* ConvertIndirectToDirect */ false);
421+
GenericFuncSpecializer FuncSpecializer(Specialized,
422+
ReInfo.getClonerParamSubstitutions(),
423+
Specialized->isFragile(), ReInfo);
424+
425+
SILFunction *GenericSpecializedFunc = FuncSpecializer.trySpecialization();
426+
if (!GenericSpecializedFunc)
427+
return nullptr;
428+
GenericSpecialized = std::make_pair(GenericSpecializedFunc, Specialized);
429+
return GenericSpecializedFunc;
430+
}
379431
return Specialized;
380432
}
381433

382434
bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
383-
// Check if the partial_apply has generic substitutions.
384-
// FIXME: We could handle generic thunks if it's worthwhile.
385-
if (PAI->hasSubstitutions())
386-
return false;
387-
388435
SILFunction *SubstF = PAI->getReferencedFunction();
389436
if (!SubstF)
390437
return false;
391438
if (SubstF->isExternalDeclaration())
392439
return false;
393440

394-
assert(!SubstF->getLoweredFunctionType()->isPolymorphic() &&
395-
"cannot specialize generic partial apply");
441+
if (PAI->hasSubstitutions() && hasArchetypes(PAI->getSubstitutions())) {
442+
DEBUG(llvm::dbgs()
443+
<< "CapturePropagation: cannot handle partial specialization "
444+
"of partial_apply:\n";
445+
PAI->dumpInContext());
446+
return false;
447+
}
448+
396449

397450
// First possibility: Is it a partial_apply where all partially applied
398451
// arguments are dead?
399-
if (SILFunction *NewFunc = getSpecializedWithDeadParams(SubstF,
400-
PAI->getNumArguments())) {
452+
std::pair<SILFunction *, SILFunction *> GenericSpecialized;
453+
if (auto *NewFunc = getSpecializedWithDeadParams(
454+
PAI, SubstF, PAI->getNumArguments(), GenericSpecialized)) {
401455
rewritePartialApply(PAI, NewFunc);
456+
if (GenericSpecialized.first) {
457+
// Notify the pass manager about the new function.
458+
notifyPassManagerOfFunction(GenericSpecialized.first,
459+
GenericSpecialized.second);
460+
}
402461
return true;
403462
}
404463

@@ -411,7 +470,8 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
411470
return false;
412471

413472
DEBUG(llvm::dbgs() << "Specializing closure for constant arguments:\n"
414-
<< " " << SubstF->getName() << "\n" << *PAI);
473+
<< " " << SubstF->getName() << "\n"
474+
<< *PAI);
415475
++NumCapturesPropagated;
416476
SILFunction *NewF = specializeConstClosure(PAI, SubstF);
417477
rewritePartialApply(PAI, NewF);

test/SILOptimizer/capture_propagation.sil

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,75 @@ bb0:
378378
return %2 : $@callee_owned (Int32, Int32) -> (Bool, @error Error)
379379
}
380380

381+
// Test generic capture propagation
382+
383+
sil @_TFtest_generic_capture_propagation2_closure : $@convention(thin) <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> () {
384+
bb0(%0 : $Builtin.Int32, %1 : $Builtin.FPIEEE32, %2 : $Builtin.RawPointer, %3 : $*T):
385+
%9999 = tuple()
386+
return %9999 : $()
387+
}
388+
389+
// CHECK-LABEL: sil @test_generic_capture_propagation2_caller
390+
// CHECK: %[[CALLEE:[0-9]+]] = function_ref @test_generic_capture_propagation2_callee
391+
// CHECK: %[[FR:[0-9]+]] = function_ref @{{.*}}test_generic_capture_propagation2_thunk : $@convention(thin) () -> ()
392+
// CHECK: %[[CONVERTED:[0-9]+]] = thin_to_thick_function %[[FR]] : $@convention(thin) () -> () to $@callee_owned () -> ()
393+
// CHECK-NOT: partial_apply
394+
// CHECK: apply %[[CALLEE]](%[[CONVERTED]]) : $@convention(thin) (@owned @callee_owned () -> ()) -> ()
395+
// CHECL-NOT: partial_apply
396+
// CHECK: return
397+
sil @test_generic_capture_propagation2_caller : $@convention(thin) () -> () {
398+
%0 = integer_literal $Builtin.Int32, 0
399+
%1 = float_literal $Builtin.FPIEEE32, 0
400+
%2 = string_literal utf8 "123"
401+
%3 = global_addr @globalinit_33_06E7F1D906492AE070936A9B58CBAE1C_token8 : $*Builtin.Word
402+
%4 = function_ref @_TFtest_generic_capture_propagation2_closure : $@convention(thin) <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> ()
403+
%5 = thin_to_thick_function %4 : $@convention(thin) <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> () to $@callee_owned <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> ()
404+
%6 = function_ref @test_generic_capture_propagation2_callee : $@convention(thin) (@owned @callee_owned () -> ()) -> ()
405+
%7 = function_ref @test_generic_capture_propagation2_thunk : $@convention(thin) <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T, @owned @callee_owned <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> ()) -> ()
406+
%8 = partial_apply %7<Builtin.Word>(%0, %1, %2, %3, %5) : $@convention(thin) <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T, @owned @callee_owned <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> ()) -> ()
407+
apply %6(%8) : $@convention(thin) (@owned @callee_owned () -> ()) -> ()
408+
%9999 = tuple()
409+
return %9999 : $()
410+
}
411+
412+
sil shared @test_generic_capture_propagation2_thunk : $@convention(thin) <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T, @owned @callee_owned <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> ()) -> () {
413+
bb0(%0 : $Builtin.Int32, %1 : $Builtin.FPIEEE32, %2 : $Builtin.RawPointer, %3 : $*T, %4 : $@callee_owned <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> ()):
414+
apply %4<T>(%0, %1, %2, %3) : $@callee_owned <T> (Builtin.Int32, Builtin.FPIEEE32, Builtin.RawPointer, @in T) -> ()
415+
%9999 = tuple()
416+
return %9999 : $()
417+
}
418+
419+
sil shared @test_generic_capture_propagation2_callee : $@convention(thin) (@owned @callee_owned () -> ()) -> () {
420+
bb0(%0 : $@callee_owned () -> ()):
421+
apply %0() : $@callee_owned () -> ()
422+
%9999 = tuple()
423+
return %9999 : $()
424+
}
425+
426+
// Test dead partial applied arguments when using generics
427+
428+
sil @specialized_generic_nonthrowing_closure : $@convention(thin) <T> (@in T, @in T) -> Bool {
429+
bb0(%0 : $*T, %1 : $*T):
430+
%10 = integer_literal $Builtin.Int1, -1
431+
%9999 = struct $Bool (%10 : $Builtin.Int1)
432+
return %9999 : $Bool
433+
}
434+
435+
sil @nonthrowing_generic_closure : $@convention(method) <T> (@in T, @in T, @thin T.Type) -> Bool {
436+
bb0(%0 : $*T, %1 : $*T, %2 : $@thin T.Type):
437+
%3 = function_ref @specialized_generic_nonthrowing_closure : $@convention(thin) <T> (@in T, @in T) -> Bool
438+
%4 = apply %3<T>(%0, %1) : $@convention(thin) <T> (@in T, @in T) -> Bool
439+
return %4 : $Bool
440+
}
441+
442+
// CHECK-LABEL: sil @return_generic_nonthrowing_closure
443+
// CHECK: [[F:%[0-9]+]] = function_ref @_TTSg5Vs5Int32__specialized_generic_nonthrowing_closure
444+
// CHECK: [[R:%[0-9]+]] = thin_to_thick_function [[F]]
445+
// CHECK: return [[R]]
446+
sil @return_generic_nonthrowing_closure : $@convention(thin) () -> @owned @callee_owned (@in Int32, @in Int32) -> Bool {
447+
bb0:
448+
%0 = metatype $@thin Int32.Type
449+
%1 = function_ref @nonthrowing_generic_closure : $@convention(method) <T> (@in T, @in T, @thin T.Type) -> Bool
450+
%2 = partial_apply %1<Int32>(%0) : $@convention(method) <T>(@in T, @in T, @thin T.Type) -> Bool
451+
return %2 : $@callee_owned (@in Int32, @in Int32) -> Bool
452+
}

0 commit comments

Comments
 (0)