Skip to content

Commit 63b3d0b

Browse files
authored
[AutoDiff] Make activity analysis work on side effecting code. (#21677)
1 parent e15142c commit 63b3d0b

File tree

6 files changed

+173
-32
lines changed

6 files changed

+173
-32
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ ERROR(autodiff_function_not_differentiable,none,
376376
"function is not differentiable", ())
377377
ERROR(autodiff_property_not_differentiable,none,
378378
"property is not differentiable", ())
379+
ERROR(autodiff_expression_is_not_differentiable_error,none,
380+
"expression is not differentiable", ())
379381
NOTE(autodiff_function_generic_functions_unsupported,none,
380382
"differentiating generic functions is not supported yet", ())
381383
NOTE(autodiff_value_defined_here,none,

include/swift/SIL/SILFunction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,10 @@ class SILFunction
697697

698698
void addDifferentiableAttr(SILDifferentiableAttr *attr);
699699

700+
void removeDifferentiableAttr(SILDifferentiableAttr *attr) {
701+
std::remove(DifferentiableAttrs.begin(), DifferentiableAttrs.end(), attr);
702+
}
703+
700704
/// Get this function's optimization mode or OptimizationMode::NotSet if it is
701705
/// not set for this specific function.
702706
OptimizationMode getOptimizationMode() const { return OptMode; }

include/swift/SIL/SILType.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,10 @@ class SILType {
497497
/// formal type. Meant for verification purposes/assertions.
498498
bool isLoweringOf(SILModule &M, CanType formalType);
499499

500+
// SWIFT_ENABLE_TENSORFLOW
501+
/// Returns true if this SILType is a differentiable type.
502+
bool isDifferentiable(SILModule &M) const;
503+
500504
/// Returns the hash code for the SILType.
501505
llvm::hash_code getHashCode() const {
502506
return llvm::hash_combine(*this);

lib/SIL/SILType.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,11 @@ bool SILType::isLoweringOf(SILModule &Mod, CanType formalType) {
598598
// Other types are preserved through lowering.
599599
return loweredType.getASTType() == formalType;
600600
}
601+
602+
// SWIFT_ENABLE_TENSORFLOW
603+
/// Returns true if this SILType is a differentiable type.
604+
bool SILType::isDifferentiable(SILModule &M) const {
605+
return getASTType()->getAutoDiffAssociatedVectorSpace(
606+
AutoDiffAssociatedVectorSpaceKind::Tangent,
607+
LookUpConformanceInModule(M.getSwiftModule())).hasValue();
608+
}

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 144 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,20 @@ class ADContext {
10321032
return cachedPlusFn;
10331033
}
10341034

1035+
void clearTask(DifferentiationTask *task) {
1036+
LLVM_DEBUG(getADDebugStream() << "Clearing differentiation task for "
1037+
<< task->original->getName() << '\n');
1038+
transform.notifyWillDeleteFunction(task->primal);
1039+
module.eraseFunction(task->primal);
1040+
transform.notifyWillDeleteFunction(task->adjoint);
1041+
module.eraseFunction(task->adjoint);
1042+
transform.notifyWillDeleteFunction(task->jvp);
1043+
module.eraseFunction(task->jvp);
1044+
transform.notifyWillDeleteFunction(task->vjp);
1045+
module.eraseFunction(task->vjp);
1046+
task->original->removeDifferentiableAttr(task->attr);
1047+
}
1048+
10351049
/// Retrieves the file unit that contains implicit declarations in the
10361050
/// current Swift module. If it does not exist, create one.
10371051
///
@@ -1220,10 +1234,13 @@ void ADContext::emitNondifferentiabilityError(SILInstruction *inst,
12201234
// Location of the instruction.
12211235
auto opLoc = inst->getLoc().getSourceLoc();
12221236
auto invoker = task->getInvoker();
1223-
LLVM_DEBUG(getADDebugStream()
1224-
<< "Diagnosing non-differentiability for value \n\t" << *inst
1225-
<< "\n"
1226-
<< "while performing differentiation task\n\t" << task << '\n');
1237+
LLVM_DEBUG({
1238+
auto &s = getADDebugStream()
1239+
<< "Diagnosing non-differentiability for value \n\t" << *inst
1240+
<< "\nwhile performing differentiation task\n\t";
1241+
task->print(s);
1242+
s << '\n';
1243+
});
12271244
switch (invoker.getKind()) {
12281245
// For a `autodiff_function` instruction or a `[differentiable]` attribute
12291246
// that is not associated with any source location, we emit a diagnostic at
@@ -1407,6 +1424,11 @@ class DifferentiableActivityInfo {
14071424
/// Perform analysis and populate sets.
14081425
void analyze(DominanceInfo *di, PostDominanceInfo *pdi);
14091426

1427+
void setVariedIfDifferentiable(SILValue value,
1428+
unsigned independentVariableIndex);
1429+
void setUsefulIfDifferentiable(SILValue value,
1430+
unsigned dependentVariableIndex);
1431+
14101432
public:
14111433
explicit DifferentiableActivityInfo(SILFunction &f,
14121434
DominanceInfo *di,
@@ -1459,9 +1481,8 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
14591481
<< "Running activity analysis on @" << function.getName() << '\n');
14601482
// Inputs are just function's arguments, count `n`.
14611483
auto paramArgs = function.getArgumentsWithoutIndirectResults();
1462-
for (auto valueAndIndex : enumerate(paramArgs)) {
1463-
inputValues.push_back(valueAndIndex.first);
1464-
}
1484+
for (auto value : paramArgs)
1485+
inputValues.push_back(value);
14651486
LLVM_DEBUG({
14661487
auto &s = getADDebugStream();
14671488
s << "Inputs in @" << function.getName() << ":\n";
@@ -1477,39 +1498,111 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
14771498
s << val << '\n';
14781499
});
14791500

1501+
auto &module = function.getModule();
14801502
// Mark inputs as varied.
14811503
assert(variedValueSets.empty());
1482-
for (auto input : inputValues)
1483-
variedValueSets.push_back({input});
1504+
for (auto input : inputValues) {
1505+
variedValueSets.push_back({});
1506+
if (input->getType().isDifferentiable(module))
1507+
variedValueSets.back().insert(input);
1508+
}
14841509
// Propagate varied-ness through the function in dominance order.
14851510
DominanceOrder domOrder(function.getEntryBlock(), di);
14861511
while (auto *block = domOrder.getNext()) {
1487-
for (auto &inst : *block)
1488-
for (auto &op : inst.getAllOperands())
1489-
for (auto i : indices(inputValues))
1490-
if (isVaried(op.get(), i))
1491-
for (auto result : inst.getResults())
1492-
variedValueSets[i].insert(result);
1512+
for (auto &inst : *block) {
1513+
for (auto i : indices(inputValues)) {
1514+
// Handle `apply`.
1515+
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
1516+
for (auto arg : ai->getArgumentsWithoutIndirectResults()) {
1517+
if (isVaried(arg, i)) {
1518+
for (auto indRes : ai->getIndirectSILResults())
1519+
setVariedIfDifferentiable(indRes, i);
1520+
for (auto dirRes : ai->getResults())
1521+
setVariedIfDifferentiable(dirRes, i);
1522+
}
1523+
}
1524+
}
1525+
// Handle `store`.
1526+
else if (auto *si = dyn_cast<StoreInst>(&inst)) {
1527+
if (isVaried(si->getSrc(), i))
1528+
setVariedIfDifferentiable(si->getDest(), i);
1529+
}
1530+
// Handle everything else.
1531+
else {
1532+
for (auto &op : inst.getAllOperands())
1533+
if (isVaried(op.get(), i))
1534+
for (auto result : inst.getResults())
1535+
setVariedIfDifferentiable(result, i);
1536+
}
1537+
}
1538+
}
14931539
domOrder.pushChildren(block);
14941540
}
14951541

1496-
// Mark outputs as useful.
1542+
// Mark differentiable outputs as useful.
14971543
assert(usefulValueSets.empty());
1498-
for (auto output : outputValues)
1499-
usefulValueSets.push_back({output});
1544+
for (auto output : outputValues) {
1545+
usefulValueSets.push_back({});
1546+
if (output->getType().isDifferentiable(module))
1547+
usefulValueSets.back().insert(output);
1548+
}
15001549
// Propagate usefulness through the function in post-dominance order.
15011550
PostDominanceOrder postDomOrder(&*function.findReturnBB(), pdi);
15021551
while (auto *block = postDomOrder.getNext()) {
1503-
for (auto &inst : reversed(*block))
1504-
for (auto result : inst.getResults())
1505-
for (auto i : indices(outputValues))
1506-
if (isUseful(result, i))
1507-
for (auto &op : inst.getAllOperands())
1508-
usefulValueSets[i].insert(op.get());
1552+
for (auto &inst : reversed(*block)) {
1553+
for (auto i : indices(outputValues)) {
1554+
// Handle indirect results in `apply`.
1555+
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
1556+
auto checkAndSetUseful = [&](SILValue res) {
1557+
if (isUseful(res, i))
1558+
for (auto arg : ai->getArgumentsWithoutIndirectResults())
1559+
setUsefulIfDifferentiable(arg, i);
1560+
};
1561+
for (auto dirRes : ai->getResults())
1562+
checkAndSetUseful(dirRes);
1563+
for (auto indRes : ai->getIndirectSILResults())
1564+
checkAndSetUseful(indRes);
1565+
}
1566+
// Handle `store`.
1567+
else if (auto *si = dyn_cast<StoreInst>(&inst)) {
1568+
if (isUseful(si->getDest(), i))
1569+
setUsefulIfDifferentiable(si->getSrc(), i);
1570+
}
1571+
// Handle side-effecting operations.
1572+
else if (inst.mayHaveSideEffects()) {
1573+
for (auto &op : inst.getAllOperands())
1574+
if (op.get()->getType().isAddress())
1575+
setUsefulIfDifferentiable(op.get(), i);
1576+
for (auto result : inst.getResults())
1577+
setUsefulIfDifferentiable(result, i);
1578+
}
1579+
// Handle everything else.
1580+
else {
1581+
for (auto result : inst.getResults())
1582+
if (isUseful(result, i))
1583+
for (auto &op : inst.getAllOperands())
1584+
setUsefulIfDifferentiable(op.get(), i);
1585+
}
1586+
}
1587+
}
15091588
postDomOrder.pushChildren(block);
15101589
}
15111590
}
15121591

1592+
void DifferentiableActivityInfo::setVariedIfDifferentiable(
1593+
SILValue value, unsigned independentVariableIndex) {
1594+
if (!value->getType().isDifferentiable(function.getModule()))
1595+
return;
1596+
variedValueSets[independentVariableIndex].insert(value);
1597+
}
1598+
1599+
void DifferentiableActivityInfo::setUsefulIfDifferentiable(
1600+
SILValue value, unsigned dependentVariableIndex) {
1601+
if (!value->getType().isDifferentiable(function.getModule()))
1602+
return;
1603+
usefulValueSets[dependentVariableIndex].insert(value);
1604+
}
1605+
15131606
bool DifferentiableActivityInfo::isIndependent(
15141607
SILValue value, const SILAutoDiffIndices &indices) const {
15151608
for (auto paramIdx : indices.parameters.set_bits())
@@ -2181,11 +2274,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
21812274
// Clone.
21822275
cloneFunctionBody(original, entry, entryArgs);
21832276
// If errors occurred, back out.
2184-
if (errorOccurred) {
2185-
// Delete the body so that later passes don't get confused by invalid SIL.
2186-
getPrimal()->getBlocks().clear();
2277+
if (errorOccurred)
21872278
return true;
2188-
}
21892279
auto *origExit = &*original->findReturnBB();
21902280
auto *exit = BBMap.lookup(origExit);
21912281
assert(exit->getParent() == getPrimal());
@@ -2249,7 +2339,15 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22492339
return;
22502340
SILClonerWithScopes::visit(inst);
22512341
}
2252-
2342+
2343+
void visitSILInstruction(SILInstruction *inst) {
2344+
// TODO: Change this to a note when we emit an error at the @autodiff
2345+
// function conversion location.
2346+
getContext().emitNondifferentiabilityError(inst, getDifferentiationTask(),
2347+
diag::autodiff_expression_is_not_differentiable_error);
2348+
errorOccurred = true;
2349+
}
2350+
22532351
void visitReturnInst(ReturnInst *ri) {
22542352
// The original return is not to be cloned.
22552353
return;
@@ -2702,6 +2800,7 @@ bool PrimalGen::performSynthesis(FunctionSynthesisItem item) {
27022800
// Synthesize primal.
27032801
PrimalGenCloner cloner(item, activityInfo, domInfo, pdomInfo, loopInfo, *this,
27042802
context);
2803+
// Run the cloner.
27052804
return cloner.run();
27062805
}
27072806

@@ -2724,7 +2823,10 @@ bool PrimalGen::run() {
27242823
while (!worklist.empty()) {
27252824
auto synthesis = worklist.back();
27262825
worklist.pop_back();
2727-
errorOccurred |= performSynthesis(synthesis);
2826+
if (performSynthesis(synthesis)) {
2827+
context.clearTask(synthesis.task);
2828+
errorOccurred = true;
2829+
}
27282830
synthesis.task->getPrimalInfo()->computePrimalValueStructType();
27292831
synthesis.task->setPrimalSynthesisState(FunctionSynthesisState::Done);
27302832
}
@@ -2780,7 +2882,10 @@ bool AdjointGen::run() {
27802882
while (!worklist.empty()) {
27812883
auto synthesis = worklist.back();
27822884
worklist.pop_back();
2783-
errorOccurred |= performSynthesis(synthesis);
2885+
if (performSynthesis(synthesis)) {
2886+
context.clearTask(synthesis.task);
2887+
errorOccurred = true;
2888+
}
27842889
synthesis.task->setAdjointSynthesisState(FunctionSynthesisState::Done);
27852890
}
27862891
return errorOccurred;
@@ -3243,6 +3348,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
32433348
continue;
32443349
// Differentiate instruction.
32453350
visit(&inst);
3351+
if (errorOccurred)
3352+
return true;
32463353
}
32473354
}
32483355

@@ -3322,7 +3429,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
33223429
}
33233430

33243431
void visitSILInstruction(SILInstruction *inst) {
3325-
llvm_unreachable("Unsupport instruction visited");
3432+
// TODO: Change this to a note when we emit an error at the @autodiff
3433+
// function conversion location.
3434+
getContext().emitNondifferentiabilityError(inst, getDifferentiationTask(),
3435+
diag::autodiff_expression_is_not_differentiable_error);
3436+
errorOccurred = true;
33263437
}
33273438

33283439
SILLocation remapLocation(SILLocation loc) { return loc; }
@@ -4326,6 +4437,7 @@ bool AdjointGen::performSynthesis(FunctionSynthesisItem item) {
43264437
*domAnalysis->get(item.original),
43274438
*pdomAnalysis->get(item.original),
43284439
*loopAnalysis->get(item.original), *this);
4440+
// Run the adjoint emitter.
43294441
return emitter.run();
43304442
}
43314443

test/AutoDiff/side_effects.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
func simpleStoreLoad(x: Float) -> Float {
4+
var y = x
5+
y = x + 1
6+
// expected-error @+1 {{expression is not differentiable}}
7+
return y
8+
}
9+
let _: @autodiff (Float) -> Float = simpleStoreLoad(x:)
10+
11+
// TODO: Add file checks.

0 commit comments

Comments
 (0)