@@ -97,6 +97,15 @@ struct TermArgSources {
97
97
98
98
} // namespace
99
99
100
+ // / Used for generating informative diagnostics.
101
+ static Expr *getExprForPartitionOp (const PartitionOp &op) {
102
+ SILInstruction *sourceInstr = op.getSourceInst (/* assertNonNull=*/ true );
103
+ Expr *expr = sourceInstr->getLoc ().getAsASTNode <Expr>();
104
+ assert (expr && " PartitionOp's source location should correspond to"
105
+ " an AST node" );
106
+ return expr;
107
+ }
108
+
100
109
// ===----------------------------------------------------------------------===//
101
110
// MARK: Main Computation
102
111
// ===----------------------------------------------------------------------===//
@@ -1300,52 +1309,99 @@ class ConsumeRequireAccumulator {
1300
1309
std::map<PartitionOp, std::set<PartitionOpAtDistance>>
1301
1310
requirementsForConsumptions;
1302
1311
1312
+ SILFunction *fn;
1313
+
1303
1314
public:
1304
- ConsumeRequireAccumulator () {}
1315
+ ConsumeRequireAccumulator (SILFunction *fn) : fn(fn ) {}
1305
1316
1306
1317
void accumulateConsumedReason (PartitionOp requireOp, const ConsumedReason &consumedReason) {
1307
1318
for (auto [distance, consumeOps] : consumedReason.consumeOps )
1308
1319
for (auto consumeOp : consumeOps)
1309
1320
requirementsForConsumptions[consumeOp].insert ({requireOp, distance});
1310
1321
}
1311
1322
1312
- // for each consumption in this ConsumeRequireAccumulator, call the passed
1313
- // processConsumeOp closure on it, followed immediately by calling the passed
1314
- // processRequireOp closure on the top `numRequiresPerConsume` operations
1315
- // that access ("require") the region consumed. Sorting is by lowest distance
1316
- // first, then arbitrarily. This is used for final diagnostic output.
1317
- void forEachConsumeRequire (
1318
- llvm::function_ref<void (const PartitionOp& consumeOp, unsigned numProcessed, unsigned numSkipped)>
1319
- processConsumeOp,
1320
- llvm::function_ref<void(const PartitionOp& requireOp)>
1321
- processRequireOp,
1322
- unsigned numRequiresPerConsume = UINT_MAX) const {
1323
+ void
1324
+ emitErrorsForConsumeRequire (unsigned numRequiresPerConsume = UINT_MAX) const {
1323
1325
for (auto [consumeOp, requireOps] : requirementsForConsumptions) {
1324
1326
unsigned numProcessed = std::min ({(unsigned ) requireOps.size (),
1325
1327
(unsigned ) numRequiresPerConsume});
1326
- processConsumeOp (consumeOp, numProcessed, requireOps.size () - numProcessed);
1328
+
1329
+ // First process our consume ops.
1330
+ unsigned numDisplayed = numProcessed;
1331
+ unsigned numHidden = requireOps.size () - numProcessed;
1332
+ if (!tryDiagnoseAsCallSite (consumeOp, numDisplayed, numHidden)) {
1333
+ assert (false && " no consumptions besides callsites implemented yet" );
1334
+
1335
+ // default to more generic diagnostic
1336
+ auto expr = getExprForPartitionOp (consumeOp);
1337
+ auto diag = fn->getASTContext ().Diags .diagnose (
1338
+ expr->getLoc (), diag::consumption_yields_race, numDisplayed,
1339
+ numDisplayed != 1 , numHidden > 0 , numHidden);
1340
+ if (auto sourceExpr = consumeOp.getSourceExpr ())
1341
+ diag.highlight (sourceExpr->getSourceRange ());
1342
+ return ;
1343
+ }
1344
+
1327
1345
unsigned numRequiresToProcess = numRequiresPerConsume;
1328
1346
for (auto [requireOp, _] : requireOps) {
1329
1347
// ensures at most numRequiresPerConsume requires are processed per consume
1330
- if (numRequiresToProcess-- == 0 ) break ;
1331
- processRequireOp (requireOp);
1348
+ if (numRequiresToProcess-- == 0 )
1349
+ break ;
1350
+ auto expr = getExprForPartitionOp (requireOp);
1351
+ fn->getASTContext ()
1352
+ .Diags .diagnose (expr->getLoc (), diag::possible_racy_access_site)
1353
+ .highlight (expr->getSourceRange ());
1332
1354
}
1333
1355
}
1334
1356
}
1335
1357
1336
1358
SWIFT_DEBUG_DUMP { print (llvm::dbgs ()); }
1337
1359
1338
1360
void print (llvm::raw_ostream &os) const {
1339
- forEachConsumeRequire (
1340
- [&](const PartitionOp &consumeOp, unsigned numProcessed,
1341
- unsigned numSkipped) {
1342
- os << " ┌──╼ CONSUME: " ;
1343
- consumeOp.print (os);
1344
- },
1345
- [&](const PartitionOp &requireOp) {
1346
- os << " ├╼ REQUIRE: " ;
1347
- requireOp.print (os);
1348
- });
1361
+ for (auto [consumeOp, requireOps] : requirementsForConsumptions) {
1362
+ os << " ┌──╼ CONSUME: " ;
1363
+ consumeOp.print (os);
1364
+
1365
+ for (auto &[requireOp, _] : requireOps) {
1366
+ os << " ├╼ REQUIRE: " ;
1367
+ requireOp.print (os);
1368
+ }
1369
+ }
1370
+ }
1371
+
1372
+ private:
1373
+ // / Try to interpret this consumeOp as a source-level callsite (ApplyExpr),
1374
+ // / and report a diagnostic including actor isolation crossing information
1375
+ // / returns true iff one was succesfully formed and emitted.
1376
+ bool tryDiagnoseAsCallSite (const PartitionOp &consumeOp,
1377
+ unsigned numDisplayed, unsigned numHidden) const {
1378
+ SILInstruction *sourceInst =
1379
+ consumeOp.getSourceInst (/* assertNonNull=*/ true );
1380
+ ApplyExpr *apply = sourceInst->getLoc ().getAsASTNode <ApplyExpr>();
1381
+ if (!apply)
1382
+ // consumption does not correspond to an apply expression
1383
+ return false ;
1384
+ auto isolationCrossing = apply->getIsolationCrossing ();
1385
+ if (!isolationCrossing) {
1386
+ assert (false && " ApplyExprs should be consuming only if"
1387
+ " they are isolation crossing" );
1388
+ return false ;
1389
+ }
1390
+ auto argExpr = consumeOp.getSourceExpr ();
1391
+ if (!argExpr)
1392
+ assert (false &&
1393
+ " sourceExpr should be populated for ApplyExpr consumptions" );
1394
+
1395
+ sourceInst->getFunction ()
1396
+ ->getASTContext ()
1397
+ .Diags
1398
+ .diagnose (argExpr->getLoc (), diag::call_site_consumption_yields_race,
1399
+ argExpr->findOriginalType (),
1400
+ isolationCrossing.value ().getCallerIsolation (),
1401
+ isolationCrossing.value ().getCalleeIsolation (), numDisplayed,
1402
+ numDisplayed != 1 , numHidden > 0 , numHidden)
1403
+ .highlight (argExpr->getSourceRange ());
1404
+ return true ;
1349
1405
}
1350
1406
};
1351
1407
@@ -1560,8 +1616,9 @@ class RaceTracer {
1560
1616
}
1561
1617
1562
1618
public:
1563
- RaceTracer (const BasicBlockData<BlockPartitionState>& blockStates)
1564
- : blockStates(blockStates) {}
1619
+ RaceTracer (SILFunction *fn,
1620
+ const BasicBlockData<BlockPartitionState> &blockStates)
1621
+ : blockStates(blockStates), accumulator(fn) {}
1565
1622
1566
1623
void traceUseOfConsumedValue (PartitionOp use, TrackableValueID consumedVal) {
1567
1624
accumulator.accumulateConsumedReason (
@@ -1601,9 +1658,7 @@ class PartitionAnalysis {
1601
1658
[this ](SILBasicBlock *block) {
1602
1659
return BlockPartitionState (block, translator);
1603
1660
}),
1604
- raceTracer (blockStates),
1605
- function(fn),
1606
- solved(false ) {
1661
+ raceTracer (fn, blockStates), function(fn), solved(false ) {
1607
1662
// initialize the entry block as needing an update, and having a partition
1608
1663
// that places all its non-sendable args in a single region
1609
1664
blockStates[fn->getEntryBlock ()].needsUpdate = true ;
@@ -1692,15 +1747,6 @@ class PartitionAnalysis {
1692
1747
return false ;
1693
1748
}
1694
1749
1695
- // used for generating informative diagnostics
1696
- Expr *getExprForPartitionOp (const PartitionOp& op) {
1697
- SILInstruction *sourceInstr = op.getSourceInst (/* assertNonNull=*/ true );
1698
- Expr *expr = sourceInstr->getLoc ().getAsASTNode <Expr>();
1699
- assert (expr && " PartitionOp's source location should correspond to"
1700
- " an AST node" );
1701
- return expr;
1702
- }
1703
-
1704
1750
// once the fixpoint has been solved for, run one more pass over each basic
1705
1751
// block, reporting any failures due to requiring consumed regions in the
1706
1752
// fixpoint state
@@ -1710,7 +1756,7 @@ class PartitionAnalysis {
1710
1756
LLVM_DEBUG (
1711
1757
llvm::dbgs () << " Emitting diagnostics for function "
1712
1758
<< function->getName () << " \n " );
1713
- RaceTracer tracer = blockStates;
1759
+ RaceTracer tracer (function, blockStates) ;
1714
1760
1715
1761
for (auto [_, blockState] : blockStates) {
1716
1762
// populate the raceTracer with all requires of consumed valued found
@@ -1738,40 +1784,12 @@ class PartitionAnalysis {
1738
1784
LLVM_DEBUG (llvm::dbgs () << " Accumulator Complete:\n " ;
1739
1785
raceTracer.getAccumulator ().print (llvm::dbgs ()););
1740
1786
1741
- // ask the raceTracer to report diagnostics at the consumption sites
1742
- // for all the racy requirement sites entered into it above
1743
- raceTracer.getAccumulator ().forEachConsumeRequire (
1744
- /* diagnoseConsume=*/
1745
- [&](const PartitionOp& consumeOp,
1746
- unsigned numDisplayed, unsigned numHidden) {
1747
-
1748
- if (tryDiagnoseAsCallSite (consumeOp, numDisplayed, numHidden))
1749
- return ;
1750
-
1751
- assert (false && " no consumptions besides callsites implemented yet" );
1752
-
1753
- // default to more generic diagnostic
1754
- auto expr = getExprForPartitionOp (consumeOp);
1755
- auto diag = function->getASTContext ().Diags .diagnose (
1756
- expr->getLoc (), diag::consumption_yields_race,
1757
- numDisplayed, numDisplayed != 1 , numHidden > 0 , numHidden);
1758
- if (auto sourceExpr = consumeOp.getSourceExpr ())
1759
- diag.highlight (sourceExpr->getSourceRange ());
1760
- },
1761
-
1762
- /* diagnoseRequire=*/
1763
- [&](const PartitionOp& requireOp) {
1764
- auto expr = getExprForPartitionOp (requireOp);
1765
- function->getASTContext ().Diags .diagnose (
1766
- expr->getLoc (), diag::possible_racy_access_site)
1767
- .highlight (expr->getSourceRange ());
1768
- },
1787
+ // Ask the raceTracer to report diagnostics at the consumption sites for all
1788
+ // the racy requirement sites entered into it above.
1789
+ raceTracer.getAccumulator ().emitErrorsForConsumeRequire (
1769
1790
NUM_REQUIREMENTS_TO_DIAGNOSE);
1770
1791
}
1771
1792
1772
- // try to interpret this consumeOp as a source-level callsite (ApplyExpr),
1773
- // and report a diagnostic including actor isolation crossing information
1774
- // returns true iff one was succesfully formed and emitted
1775
1793
bool tryDiagnoseAsCallSite (
1776
1794
const PartitionOp& consumeOp, unsigned numDisplayed, unsigned numHidden) {
1777
1795
SILInstruction *sourceInst = consumeOp.getSourceInst (/* assertNonNull=*/ true );
0 commit comments