Skip to content

Commit fbad9b6

Browse files
authored
Merge pull request #35774 from gottesmm/pr-aed4f36ace8254f8ea9b635f05a8c903fd964553
[capture-promotion] Emit a warning when the pass fails to promote a capture of a concurrent function to a by value capture instead of by ref capture.
2 parents 06dc593 + fa19cc9 commit fbad9b6

9 files changed

+642
-16
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,5 +645,15 @@ NOTE(box_to_stack_cannot_promote_box_to_stack_due_to_escape_location, none,
645645

646646
WARNING(semantic_function_improper_nesting, none, "'@_semantics' function calls non-'@_semantics' function with nested '@_semantics' calls", ())
647647

648+
// Capture promotion diagnostics
649+
WARNING(capturepromotion_concurrentcapture_mutation, none,
650+
"'%0' mutated after capture by concurrent closure", (StringRef))
651+
NOTE(capturepromotion_concurrentcapture_closure_here, none,
652+
"variable captured by concurrent closure", ())
653+
NOTE(capturepromotion_concurrentcapture_capturinguse_here, none,
654+
"capturing use", ())
655+
NOTE(capturepromotion_variable_defined_here,none,
656+
"variable defined here", ())
657+
648658
#define UNDEFINE_DIAGNOSTIC_MACROS
649659
#include "DefineDiagnosticMacros.h"

include/swift/SIL/SILType.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,13 @@ class SILType {
565565
/// Returns true if this SILType is a differentiable type.
566566
bool isDifferentiable(SILModule &M) const;
567567

568+
/// If this is a SILBoxType, return getSILBoxFieldType(). Otherwise, return
569+
/// SILType().
570+
///
571+
/// \p field Return the type of the ith field of the box. Default set to 0
572+
/// since we only support one field today. This is just future proofing.
573+
SILType getSILBoxFieldType(const SILFunction *f, unsigned field = 0);
574+
568575
/// Returns the hash code for the SILType.
569576
llvm::hash_code getHashCode() const {
570577
return llvm::hash_combine(*this);

lib/SIL/IR/SILType.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,11 @@ bool SILType::isEffectivelyExhaustiveEnumType(SILFunction *f) {
699699
return decl->isEffectivelyExhaustive(f->getModule().getSwiftModule(),
700700
f->getResilienceExpansion());
701701
}
702+
703+
SILType SILType::getSILBoxFieldType(const SILFunction *f, unsigned field) {
704+
auto *boxTy = getASTType()->getAs<SILBoxType>();
705+
if (!boxTy)
706+
return SILType();
707+
return ::getSILBoxFieldType(f->getTypeExpansionContext(), boxTy,
708+
f->getModule().Types, field);
709+
}

lib/SIL/Verifier/SILVerifier.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,13 +1671,24 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
16711671
SILFunctionConventions substConv(substTy, F.getModule());
16721672
unsigned appliedArgStartIdx =
16731673
substConv.getNumSILArguments() - PAI->getNumArguments();
1674-
for (unsigned i = 0, size = PAI->getArguments().size(); i < size; ++i) {
1674+
bool isConcurrentAndStageIsCanonical =
1675+
PAI->getFunctionType()->isConcurrent() &&
1676+
F.getModule().getStage() >= SILStage::Canonical;
1677+
for (auto p : llvm::enumerate(PAI->getArguments())) {
16751678
requireSameType(
1676-
PAI->getArguments()[i]->getType(),
1677-
substConv.getSILArgumentType(appliedArgStartIdx + i,
1679+
p.value()->getType(),
1680+
substConv.getSILArgumentType(appliedArgStartIdx + p.index(),
16781681
F.getTypeExpansionContext()),
16791682
"applied argument types do not match suffix of function type's "
16801683
"inputs");
1684+
1685+
// TODO: Expand this to also be true for address only types.
1686+
if (isConcurrentAndStageIsCanonical)
1687+
require(
1688+
!p.value()->getType().getASTType()->is<SILBoxType>() ||
1689+
p.value()->getType().getSILBoxFieldType(&F).isAddressOnly(F),
1690+
"Concurrent partial apply in canonical SIL with a loadable box "
1691+
"type argument?!");
16811692
}
16821693

16831694
// The arguments to the result function type must match the prefix of the

lib/SILOptimizer/Mandatory/CapturePromotion.cpp

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2014 - 2021 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -44,7 +44,9 @@
4444

4545
#define DEBUG_TYPE "sil-capture-promotion"
4646

47+
#include "swift/AST/DiagnosticsSIL.h"
4748
#include "swift/AST/GenericEnvironment.h"
49+
#include "swift/Basic/FrozenMultiMap.h"
4850
#include "swift/SIL/SILCloner.h"
4951
#include "swift/SIL/SILInstruction.h"
5052
#include "swift/SIL/TypeSubstCloner.h"
@@ -56,6 +58,7 @@
5658
#include "llvm/ADT/SmallSet.h"
5759
#include "llvm/ADT/Statistic.h"
5860
#include "llvm/Support/Debug.h"
61+
#include "llvm/Support/ErrorHandling.h"
5962
#include <tuple>
6063

6164
using namespace swift;
@@ -765,6 +768,15 @@ struct EscapeMutationScanningState {
765768
/// found.
766769
SmallVector<Operand *, 8> accumulatedEscapes;
767770

771+
/// A multimap that maps partial applies to the set of operands in the partial
772+
/// applies referenced function that the pass has identified as being the use
773+
/// that caused the partial apply to capture our box.
774+
///
775+
/// We use a frozen multi-map since our algorithm first accumulates this info
776+
/// and then wants to use it, perfect for the 2-stage frozen multi map.
777+
SmallFrozenMultiMap<PartialApplyInst *, Operand *, 16>
778+
accumulatedCaptureCausingUses;
779+
768780
/// A flag that we use to ensure that we only ever see 1 project_box on an
769781
/// alloc_box.
770782
bool sawProjectBoxInst;
@@ -797,15 +809,17 @@ static bool isNonMutatingLoad(SILInstruction *inst) {
797809
/// address of the box's contents), return true if this box has mutating
798810
/// captures. Return false otherwise. All of the mutating captures that we find
799811
/// are placed into \p accumulatedMutatingUses.
800-
static bool getPartialApplyArgMutationsAndEscapes(
801-
SILArgument *boxArg, SmallVectorImpl<Operand *> &accumulatedMutatingUses,
802-
SmallVectorImpl<Operand *> &accumulatedEscapes) {
812+
static bool
813+
getPartialApplyArgMutationsAndEscapes(PartialApplyInst *pai,
814+
SILArgument *boxArg,
815+
EscapeMutationScanningState &state) {
803816
SmallVector<ProjectBoxInst *, 2> projectBoxInsts;
804817

805818
// Conservatively do not allow any use of the box argument other than a
806819
// strong_release or projection, since this is the pattern expected from
807820
// SILGen.
808821
SmallVector<Operand *, 32> incrementalEscapes;
822+
SmallVector<Operand *, 32> incrementalCaptureCausingUses;
809823
for (auto *use : boxArg->getUses()) {
810824
if (isa<StrongReleaseInst>(use->getUser()) ||
811825
isa<DestroyValueInst>(use->getUser()))
@@ -827,18 +841,25 @@ static bool getPartialApplyArgMutationsAndEscapes(
827841
// function that mirrors isNonEscapingUse.
828842
auto checkIfAddrUseMutating = [&](Operand *addrUse) -> bool {
829843
unsigned initSize = incrementalEscapes.size();
830-
auto *addrInst = addrUse->getUser();
831-
if (auto *seai = dyn_cast<StructElementAddrInst>(addrInst)) {
844+
auto *addrUser = addrUse->getUser();
845+
if (auto *seai = dyn_cast<StructElementAddrInst>(addrUser)) {
832846
for (auto *seaiUse : seai->getUses()) {
833-
if (!isNonMutatingLoad(seaiUse->getUser())) {
847+
if (isNonMutatingLoad(seaiUse->getUser())) {
848+
incrementalCaptureCausingUses.push_back(seaiUse);
849+
} else {
834850
incrementalEscapes.push_back(seaiUse);
835851
}
836852
}
837853
return incrementalEscapes.size() != initSize;
838854
}
839855

840-
if (isNonMutatingLoad(addrInst) || isa<DebugValueAddrInst>(addrInst) ||
841-
isa<MarkFunctionEscapeInst>(addrInst) || isa<EndAccessInst>(addrInst)) {
856+
if (isNonMutatingLoad(addrUser)) {
857+
incrementalCaptureCausingUses.push_back(addrUse);
858+
return false;
859+
}
860+
861+
if (isa<DebugValueAddrInst>(addrUser) ||
862+
isa<MarkFunctionEscapeInst>(addrUser) || isa<EndAccessInst>(addrUser)) {
842863
return false;
843864
}
844865

@@ -859,10 +880,15 @@ static bool getPartialApplyArgMutationsAndEscapes(
859880
}
860881
}
861882

883+
auto &accCaptureCausingUses = state.accumulatedCaptureCausingUses;
884+
while (!incrementalCaptureCausingUses.empty())
885+
accCaptureCausingUses.insert(pai,
886+
incrementalCaptureCausingUses.pop_back_val());
887+
862888
if (incrementalEscapes.empty())
863889
return false;
864890
while (!incrementalEscapes.empty())
865-
accumulatedEscapes.push_back(incrementalEscapes.pop_back_val());
891+
state.accumulatedEscapes.push_back(incrementalEscapes.pop_back_val());
866892
return true;
867893
}
868894

@@ -925,8 +951,7 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
925951
// Verify that this closure is known not to mutate the captured value; if
926952
// it does, then conservatively refuse to promote any captures of this
927953
// value.
928-
if (getPartialApplyArgMutationsAndEscapes(boxArg, state.accumulatedMutations,
929-
state.accumulatedEscapes)) {
954+
if (getPartialApplyArgMutationsAndEscapes(pai, boxArg, state)) {
930955
LLVM_DEBUG(llvm::dbgs() << " FAIL: Have a mutation or escape of a "
931956
"partial apply arg?!\n");
932957
return false;
@@ -1195,6 +1220,57 @@ static bool findEscapeOrMutationUses(Operand *op,
11951220
return isNonEscapingUse(op, state);
11961221
}
11971222

1223+
/// We found a capture of \p abi in concurrent closure \p pai that we can not
1224+
/// promote to a by value capture. Emit a nice warning (FIXME: error) to warn
1225+
/// the user and provide the following information in the compiler feedback:
1226+
///
1227+
/// 1. The source loc where the variable's box is written to.
1228+
///
1229+
/// 2. The source loc of the captured variable's declaration.
1230+
///
1231+
/// 3. The source loc of the start of the concurrent closure that caused the
1232+
/// variable to be captured.
1233+
///
1234+
/// 4. All places in the concurrent closure that triggered the box's
1235+
/// capture. NOTE: For objects these are load points. For address only things
1236+
/// it is still open for debate at this point.
1237+
static void diagnoseInvalidCaptureByConcurrentClosure(
1238+
AllocBoxInst *abi, PartialApplyInst *pai,
1239+
const EscapeMutationScanningState &state, SILInstruction *mutatingUser) {
1240+
auto captureCausingUses = state.accumulatedCaptureCausingUses.find(pai);
1241+
if (!captureCausingUses) {
1242+
llvm::errs() << "Didn't find capture causing use of partial apply: "
1243+
<< *pai;
1244+
llvm::errs() << "Original Func: " << pai->getFunction()->getName() << '\n';
1245+
llvm::errs() << "Partial Applied Func: "
1246+
<< pai->getReferencedFunctionOrNull()->getName() << '\n';
1247+
llvm::report_fatal_error("standard compiler error");
1248+
}
1249+
1250+
auto &astCtx = pai->getFunction()->getASTContext();
1251+
auto &de = astCtx.Diags;
1252+
auto varInfo = abi->getVarInfo();
1253+
StringRef name = "<unknown>";
1254+
if (varInfo) {
1255+
name = varInfo->Name;
1256+
}
1257+
1258+
de.diagnoseWithNotes(
1259+
de.diagnose(mutatingUser->getLoc().getSourceLoc(),
1260+
diag::capturepromotion_concurrentcapture_mutation, name),
1261+
[&]() {
1262+
de.diagnose(abi->getLoc().getSourceLoc(),
1263+
diag::capturepromotion_variable_defined_here);
1264+
de.diagnose(pai->getLoc().getSourceLoc(),
1265+
diag::capturepromotion_concurrentcapture_closure_here);
1266+
for (auto *use : *captureCausingUses) {
1267+
de.diagnose(
1268+
use->getUser()->getLoc().getSourceLoc(),
1269+
diag::capturepromotion_concurrentcapture_capturinguse_here);
1270+
}
1271+
});
1272+
}
1273+
11981274
/// Examine an alloc_box instruction, returning true if at least one
11991275
/// capture of the boxed variable is promotable. If so, then the pair of the
12001276
/// partial_apply instruction and the index of the box argument in the closure's
@@ -1203,7 +1279,7 @@ static bool
12031279
examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
12041280
llvm::DenseMap<PartialApplyInst *, unsigned> &im) {
12051281
LLVM_DEBUG(llvm::dbgs() << "Visiting alloc box: " << *abi);
1206-
EscapeMutationScanningState state{{}, {}, false, im};
1282+
EscapeMutationScanningState state{{}, {}, {}, false, im};
12071283

12081284
// Scan the box for escaping or mutating uses.
12091285
for (auto *use : abi->getUses()) {
@@ -1220,6 +1296,7 @@ examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
12201296
return false;
12211297
}
12221298

1299+
state.accumulatedCaptureCausingUses.setFrozen();
12231300
LLVM_DEBUG(llvm::dbgs() << "We can optimize this alloc box!\n");
12241301

12251302
// Helper lambda function to determine if instruction b is strictly after
@@ -1249,6 +1326,13 @@ examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
12491326
// block is after the partial_apply.
12501327
if (ri.isReachable(pai->getParent(), user->getParent()) ||
12511328
(pai->getParent() == user->getParent() && isAfter(pai, user))) {
1329+
// If our partial apply is concurrent and we can not promote this, emit
1330+
// a warning that shows the variable, where the variable is captured,
1331+
// and the mutation that we found.
1332+
if (pai->getFunctionType()->isConcurrent()) {
1333+
diagnoseInvalidCaptureByConcurrentClosure(abi, pai, state, user);
1334+
}
1335+
12521336
LLVM_DEBUG(llvm::dbgs() << " Invalidating: " << *pai);
12531337
LLVM_DEBUG(llvm::dbgs() << " Because of user: " << *user);
12541338
auto prev = iter++;
@@ -1257,6 +1341,7 @@ examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
12571341
}
12581342
++iter;
12591343
}
1344+
12601345
// If there are no valid captures left, then stop.
12611346
if (im.empty()) {
12621347
LLVM_DEBUG(llvm::dbgs() << " Ran out of valid captures... bailing!\n");

0 commit comments

Comments
 (0)