@@ -1032,6 +1032,20 @@ class ADContext {
1032
1032
return cachedPlusFn;
1033
1033
}
1034
1034
1035
+ void clearTask (DifferentiationTask *task) {
1036
+ LLVM_DEBUG (getADDebugStream () << " Clearing differentiation task for "
1037
+ << task->original ->getName () << ' \n ' );
1038
+ transform.notifyWillDeleteFunction (task->primal );
1039
+ module .eraseFunction (task->primal );
1040
+ transform.notifyWillDeleteFunction (task->adjoint );
1041
+ module .eraseFunction (task->adjoint );
1042
+ transform.notifyWillDeleteFunction (task->jvp );
1043
+ module .eraseFunction (task->jvp );
1044
+ transform.notifyWillDeleteFunction (task->vjp );
1045
+ module .eraseFunction (task->vjp );
1046
+ task->original ->removeDifferentiableAttr (task->attr );
1047
+ }
1048
+
1035
1049
// / Retrieves the file unit that contains implicit declarations in the
1036
1050
// / current Swift module. If it does not exist, create one.
1037
1051
// /
@@ -1220,10 +1234,13 @@ void ADContext::emitNondifferentiabilityError(SILInstruction *inst,
1220
1234
// Location of the instruction.
1221
1235
auto opLoc = inst->getLoc ().getSourceLoc ();
1222
1236
auto invoker = task->getInvoker ();
1223
- LLVM_DEBUG (getADDebugStream ()
1224
- << " Diagnosing non-differentiability for value \n\t " << *inst
1225
- << " \n "
1226
- << " while performing differentiation task\n\t " << task << ' \n ' );
1237
+ LLVM_DEBUG ({
1238
+ auto &s = getADDebugStream ()
1239
+ << " Diagnosing non-differentiability for value \n\t " << *inst
1240
+ << " \n while performing differentiation task\n\t " ;
1241
+ task->print (s);
1242
+ s << ' \n ' ;
1243
+ });
1227
1244
switch (invoker.getKind ()) {
1228
1245
// For a `autodiff_function` instruction or a `[differentiable]` attribute
1229
1246
// that is not associated with any source location, we emit a diagnostic at
@@ -1407,6 +1424,11 @@ class DifferentiableActivityInfo {
1407
1424
// / Perform analysis and populate sets.
1408
1425
void analyze (DominanceInfo *di, PostDominanceInfo *pdi);
1409
1426
1427
+ void setVariedIfDifferentiable (SILValue value,
1428
+ unsigned independentVariableIndex);
1429
+ void setUsefulIfDifferentiable (SILValue value,
1430
+ unsigned dependentVariableIndex);
1431
+
1410
1432
public:
1411
1433
explicit DifferentiableActivityInfo (SILFunction &f,
1412
1434
DominanceInfo *di,
@@ -1459,9 +1481,8 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1459
1481
<< " Running activity analysis on @" << function.getName () << ' \n ' );
1460
1482
// Inputs are just function's arguments, count `n`.
1461
1483
auto paramArgs = function.getArgumentsWithoutIndirectResults ();
1462
- for (auto valueAndIndex : enumerate(paramArgs)) {
1463
- inputValues.push_back (valueAndIndex.first );
1464
- }
1484
+ for (auto value : paramArgs)
1485
+ inputValues.push_back (value);
1465
1486
LLVM_DEBUG ({
1466
1487
auto &s = getADDebugStream ();
1467
1488
s << " Inputs in @" << function.getName () << " :\n " ;
@@ -1477,39 +1498,111 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1477
1498
s << val << ' \n ' ;
1478
1499
});
1479
1500
1501
+ auto &module = function.getModule ();
1480
1502
// Mark inputs as varied.
1481
1503
assert (variedValueSets.empty ());
1482
- for (auto input : inputValues)
1483
- variedValueSets.push_back ({input});
1504
+ for (auto input : inputValues) {
1505
+ variedValueSets.push_back ({});
1506
+ if (input->getType ().isDifferentiable (module ))
1507
+ variedValueSets.back ().insert (input);
1508
+ }
1484
1509
// Propagate varied-ness through the function in dominance order.
1485
1510
DominanceOrder domOrder (function.getEntryBlock (), di);
1486
1511
while (auto *block = domOrder.getNext ()) {
1487
- for (auto &inst : *block)
1488
- for (auto &op : inst.getAllOperands ())
1489
- for (auto i : indices (inputValues))
1490
- if (isVaried (op.get (), i))
1491
- for (auto result : inst.getResults ())
1492
- variedValueSets[i].insert (result);
1512
+ for (auto &inst : *block) {
1513
+ for (auto i : indices (inputValues)) {
1514
+ // Handle `apply`.
1515
+ if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
1516
+ for (auto arg : ai->getArgumentsWithoutIndirectResults ()) {
1517
+ if (isVaried (arg, i)) {
1518
+ for (auto indRes : ai->getIndirectSILResults ())
1519
+ setVariedIfDifferentiable (indRes, i);
1520
+ for (auto dirRes : ai->getResults ())
1521
+ setVariedIfDifferentiable (dirRes, i);
1522
+ }
1523
+ }
1524
+ }
1525
+ // Handle `store`.
1526
+ else if (auto *si = dyn_cast<StoreInst>(&inst)) {
1527
+ if (isVaried (si->getSrc (), i))
1528
+ setVariedIfDifferentiable (si->getDest (), i);
1529
+ }
1530
+ // Handle everything else.
1531
+ else {
1532
+ for (auto &op : inst.getAllOperands ())
1533
+ if (isVaried (op.get (), i))
1534
+ for (auto result : inst.getResults ())
1535
+ setVariedIfDifferentiable (result, i);
1536
+ }
1537
+ }
1538
+ }
1493
1539
domOrder.pushChildren (block);
1494
1540
}
1495
1541
1496
- // Mark outputs as useful.
1542
+ // Mark differentiable outputs as useful.
1497
1543
assert (usefulValueSets.empty ());
1498
- for (auto output : outputValues)
1499
- usefulValueSets.push_back ({output});
1544
+ for (auto output : outputValues) {
1545
+ usefulValueSets.push_back ({});
1546
+ if (output->getType ().isDifferentiable (module ))
1547
+ usefulValueSets.back ().insert (output);
1548
+ }
1500
1549
// Propagate usefulness through the function in post-dominance order.
1501
1550
PostDominanceOrder postDomOrder (&*function.findReturnBB (), pdi);
1502
1551
while (auto *block = postDomOrder.getNext ()) {
1503
- for (auto &inst : reversed (*block))
1504
- for (auto result : inst.getResults ())
1505
- for (auto i : indices (outputValues))
1506
- if (isUseful (result, i))
1507
- for (auto &op : inst.getAllOperands ())
1508
- usefulValueSets[i].insert (op.get ());
1552
+ for (auto &inst : reversed (*block)) {
1553
+ for (auto i : indices (outputValues)) {
1554
+ // Handle indirect results in `apply`.
1555
+ if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
1556
+ auto checkAndSetUseful = [&](SILValue res) {
1557
+ if (isUseful (res, i))
1558
+ for (auto arg : ai->getArgumentsWithoutIndirectResults ())
1559
+ setUsefulIfDifferentiable (arg, i);
1560
+ };
1561
+ for (auto dirRes : ai->getResults ())
1562
+ checkAndSetUseful (dirRes);
1563
+ for (auto indRes : ai->getIndirectSILResults ())
1564
+ checkAndSetUseful (indRes);
1565
+ }
1566
+ // Handle `store`.
1567
+ else if (auto *si = dyn_cast<StoreInst>(&inst)) {
1568
+ if (isUseful (si->getDest (), i))
1569
+ setUsefulIfDifferentiable (si->getSrc (), i);
1570
+ }
1571
+ // Handle side-effecting operations.
1572
+ else if (inst.mayHaveSideEffects ()) {
1573
+ for (auto &op : inst.getAllOperands ())
1574
+ if (op.get ()->getType ().isAddress ())
1575
+ setUsefulIfDifferentiable (op.get (), i);
1576
+ for (auto result : inst.getResults ())
1577
+ setUsefulIfDifferentiable (result, i);
1578
+ }
1579
+ // Handle everything else.
1580
+ else {
1581
+ for (auto result : inst.getResults ())
1582
+ if (isUseful (result, i))
1583
+ for (auto &op : inst.getAllOperands ())
1584
+ setUsefulIfDifferentiable (op.get (), i);
1585
+ }
1586
+ }
1587
+ }
1509
1588
postDomOrder.pushChildren (block);
1510
1589
}
1511
1590
}
1512
1591
1592
+ void DifferentiableActivityInfo::setVariedIfDifferentiable (
1593
+ SILValue value, unsigned independentVariableIndex) {
1594
+ if (!value->getType ().isDifferentiable (function.getModule ()))
1595
+ return ;
1596
+ variedValueSets[independentVariableIndex].insert (value);
1597
+ }
1598
+
1599
+ void DifferentiableActivityInfo::setUsefulIfDifferentiable (
1600
+ SILValue value, unsigned dependentVariableIndex) {
1601
+ if (!value->getType ().isDifferentiable (function.getModule ()))
1602
+ return ;
1603
+ usefulValueSets[dependentVariableIndex].insert (value);
1604
+ }
1605
+
1513
1606
bool DifferentiableActivityInfo::isIndependent (
1514
1607
SILValue value, const SILAutoDiffIndices &indices) const {
1515
1608
for (auto paramIdx : indices.parameters .set_bits ())
@@ -2181,11 +2274,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2181
2274
// Clone.
2182
2275
cloneFunctionBody (original, entry, entryArgs);
2183
2276
// If errors occurred, back out.
2184
- if (errorOccurred) {
2185
- // Delete the body so that later passes don't get confused by invalid SIL.
2186
- getPrimal ()->getBlocks ().clear ();
2277
+ if (errorOccurred)
2187
2278
return true ;
2188
- }
2189
2279
auto *origExit = &*original->findReturnBB ();
2190
2280
auto *exit = BBMap.lookup (origExit);
2191
2281
assert (exit->getParent () == getPrimal ());
@@ -2249,7 +2339,15 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2249
2339
return ;
2250
2340
SILClonerWithScopes::visit (inst);
2251
2341
}
2252
-
2342
+
2343
+ void visitSILInstruction (SILInstruction *inst) {
2344
+ // TODO: Change this to a note when we emit an error at the @autodiff
2345
+ // function conversion location.
2346
+ getContext ().emitNondifferentiabilityError (inst, getDifferentiationTask (),
2347
+ diag::autodiff_expression_is_not_differentiable_error);
2348
+ errorOccurred = true ;
2349
+ }
2350
+
2253
2351
void visitReturnInst (ReturnInst *ri) {
2254
2352
// The original return is not to be cloned.
2255
2353
return ;
@@ -2702,6 +2800,7 @@ bool PrimalGen::performSynthesis(FunctionSynthesisItem item) {
2702
2800
// Synthesize primal.
2703
2801
PrimalGenCloner cloner (item, activityInfo, domInfo, pdomInfo, loopInfo, *this ,
2704
2802
context);
2803
+ // Run the cloner.
2705
2804
return cloner.run ();
2706
2805
}
2707
2806
@@ -2724,7 +2823,10 @@ bool PrimalGen::run() {
2724
2823
while (!worklist.empty ()) {
2725
2824
auto synthesis = worklist.back ();
2726
2825
worklist.pop_back ();
2727
- errorOccurred |= performSynthesis (synthesis);
2826
+ if (performSynthesis (synthesis)) {
2827
+ context.clearTask (synthesis.task );
2828
+ errorOccurred = true ;
2829
+ }
2728
2830
synthesis.task ->getPrimalInfo ()->computePrimalValueStructType ();
2729
2831
synthesis.task ->setPrimalSynthesisState (FunctionSynthesisState::Done);
2730
2832
}
@@ -2780,7 +2882,10 @@ bool AdjointGen::run() {
2780
2882
while (!worklist.empty ()) {
2781
2883
auto synthesis = worklist.back ();
2782
2884
worklist.pop_back ();
2783
- errorOccurred |= performSynthesis (synthesis);
2885
+ if (performSynthesis (synthesis)) {
2886
+ context.clearTask (synthesis.task );
2887
+ errorOccurred = true ;
2888
+ }
2784
2889
synthesis.task ->setAdjointSynthesisState (FunctionSynthesisState::Done);
2785
2890
}
2786
2891
return errorOccurred;
@@ -3243,6 +3348,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3243
3348
continue ;
3244
3349
// Differentiate instruction.
3245
3350
visit (&inst);
3351
+ if (errorOccurred)
3352
+ return true ;
3246
3353
}
3247
3354
}
3248
3355
@@ -3322,7 +3429,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3322
3429
}
3323
3430
3324
3431
void visitSILInstruction (SILInstruction *inst) {
3325
- llvm_unreachable (" Unsupport instruction visited" );
3432
+ // TODO: Change this to a note when we emit an error at the @autodiff
3433
+ // function conversion location.
3434
+ getContext ().emitNondifferentiabilityError (inst, getDifferentiationTask (),
3435
+ diag::autodiff_expression_is_not_differentiable_error);
3436
+ errorOccurred = true ;
3326
3437
}
3327
3438
3328
3439
SILLocation remapLocation (SILLocation loc) { return loc; }
@@ -4326,6 +4437,7 @@ bool AdjointGen::performSynthesis(FunctionSynthesisItem item) {
4326
4437
*domAnalysis->get (item.original ),
4327
4438
*pdomAnalysis->get (item.original ),
4328
4439
*loopAnalysis->get (item.original ), *this );
4440
+ // Run the adjoint emitter.
4329
4441
return emitter.run ();
4330
4442
}
4331
4443
0 commit comments