Skip to content

Commit bba00c6

Browse files
authored
[AutoDiff] Diagnose unsupported coroutine differentiation. (#28921)
Coroutines are functions that can yield values and suspend/resume execution. SIL has dedicated coroutine types. Coroutines are applied via `begin_apply`. `read` and `modify` accessors are coroutines. Coroutine differentiation requires extra support, so it should be diagnosed for now. Differentiation transform: - Generalize differentiation utilities from `ApplyInst` to `FullApplySite`. - `bool isArrayLiteralIntrinsic(FullApplySite)` - `void forEachApplyDirectResult(FullApplySite, llvm::function_ref<...>)` - `apply`: direct result is the `apply` result, unless the `apply` result has tuple-type and a single `destructure_tuple` user. Then, direct results are the `destructure_tuple` results. - `begin_apply`: direct results are the `begin_apply` results. - `try_apply`: direct results are successor blocks' arguments. - `bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite)` - Generalize activity analysis from `ApplyInst` to `FullApplySite`. - Propagate variedness for `FullApplySite` through direct/indirect results and `inout` arguments. - Propagate usefulness for `FullApplySite` through arguments. - Diagnose `begin_apply` instructions with active arguments and results in `PullbackEmitter::visitBeginApplyInst`. Sema: - Diagnose `@differentiable` attribute on `read` and `modify` coroutines. Resolves TF-1081. TF-1080 tracks coroutine differentiation support. TF-1083 tracks throwing function differentiation support.
1 parent 3b3a2bb commit bba00c6

17 files changed

+320
-81
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,9 @@ NOTE(autodiff_control_flow_not_supported,none,
541541
// TODO(TF-645): Remove when differentiation supports `ref_element_addr`.
542542
NOTE(autodiff_class_property_not_supported,none,
543543
"differentiating class properties is not yet supported", ())
544+
// TODO(TF-1080): Remove when differentiation supports `begin_apply`.
545+
NOTE(autodiff_coroutines_not_supported,none,
546+
"differentiation of coroutine calls is not yet supported", ())
544547
NOTE(autodiff_missing_return,none,
545548
"missing return for differentiation", ())
546549

include/swift/SIL/ApplySite.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,19 @@ class FullApplySite : public ApplySite {
503503
return getArguments().slice(getNumIndirectSILResults());
504504
}
505505

506+
// SWIFT_ENABLE_TENSORFLOW
507+
InoutArgumentRange getInoutArguments() const {
508+
switch (getKind()) {
509+
case FullApplySiteKind::ApplyInst:
510+
return cast<ApplyInst>(getInstruction())->getInoutArguments();
511+
case FullApplySiteKind::TryApplyInst:
512+
return cast<TryApplyInst>(getInstruction())->getInoutArguments();
513+
case FullApplySiteKind::BeginApplyInst:
514+
return cast<BeginApplyInst>(getInstruction())->getInoutArguments();
515+
}
516+
}
517+
// SWIFT_ENABLE_TENSORFLOW END
518+
506519
/// Returns true if \p op is the callee operand of this apply site
507520
/// and not an argument operand.
508521
bool isCalleeOperand(const Operand &op) const {

include/swift/SIL/SILInstruction.h

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,6 +2041,27 @@ class ApplyInstBase<Impl, Base, false> : public Base {
20412041
/// does it have the given semantics?
20422042
bool doesApplyCalleeHaveSemantics(SILValue callee, StringRef semantics);
20432043

2044+
// SWIFT_ENABLE_TENSORFLOW
2045+
/// Predicate used to filter InoutArgumentRange.
2046+
struct OperandToInoutArgument {
2047+
ArrayRef<SILParameterInfo> paramInfos;
2048+
OperandValueArrayRef arguments;
2049+
OperandToInoutArgument(ArrayRef<SILParameterInfo> paramInfos,
2050+
OperandValueArrayRef arguments)
2051+
: paramInfos(paramInfos), arguments(arguments) {
2052+
assert(paramInfos.size() == arguments.size());
2053+
}
2054+
Optional<SILValue> operator()(unsigned long i) const {
2055+
if (paramInfos[i].isIndirectMutating())
2056+
return arguments[i];
2057+
return None;
2058+
}
2059+
};
2060+
2061+
using InoutArgumentRange =
2062+
OptionalTransformRange<IntRange<unsigned long>, OperandToInoutArgument>;
2063+
// SWIFT_ENABLE_TENSORFLOW END
2064+
20442065
/// The partial specialization of ApplyInstBase for full applications.
20452066
/// Adds some methods relating to 'self' and to result types that don't
20462067
/// make sense for partial applications.
@@ -2147,31 +2168,14 @@ class ApplyInstBase<Impl, Base, true>
21472168
}
21482169

21492170
// SWIFT_ENABLE_TENSORFLOW
2150-
private:
2151-
/// Predicate used to filter InoutArgumentRange.
2152-
struct OperandToInoutArgument {
2153-
ArrayRef<SILParameterInfo> paramInfos;
2154-
OperandValueArrayRef arguments;
2155-
OperandToInoutArgument(const Impl &inst)
2156-
: paramInfos(inst.getSubstCalleeConv().getParameters()),
2157-
arguments(inst.getArgumentsWithoutIndirectResults()) {
2158-
assert(paramInfos.size() == arguments.size());
2159-
}
2160-
Optional<SILValue> operator()(unsigned long i) const {
2161-
if (paramInfos[i].isIndirectMutating())
2162-
return arguments[i];
2163-
return None;
2164-
}
2165-
};
2166-
2167-
public:
2168-
using InoutArgumentRange =
2169-
OptionalTransformRange<IntRange<unsigned long>, OperandToInoutArgument>;
21702171
/// Returns all `@inout` and `@inout_aliasable` arguments passed to the
21712172
/// instruction.
21722173
InoutArgumentRange getInoutArguments() const {
2173-
return InoutArgumentRange(indices(getArgumentsWithoutIndirectResults()),
2174-
OperandToInoutArgument(asImpl()));
2174+
auto &impl = asImpl();
2175+
return InoutArgumentRange(
2176+
indices(getArgumentsWithoutIndirectResults()),
2177+
OperandToInoutArgument(impl.getSubstCalleeConv().getParameters(),
2178+
impl.getArgumentsWithoutIndirectResults()));
21752179
}
21762180
// SWIFT_ENABLE_TENSORFLOW END
21772181
};

include/swift/SILOptimizer/Utils/Differentiation/Common.h

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ namespace autodiff {
5050
/// This is being used to print short debug messages within the AD pass.
5151
raw_ostream &getADDebugStream();
5252

53-
/// Returns true if this is an `ApplyInst` with `array.uninitialized_intrinsic`
54-
/// semantics.
55-
bool isArrayLiteralIntrinsic(ApplyInst *ai);
53+
/// Returns true if this is an full apply site whose callee has
54+
/// `array.uninitialized_intrinsic` semantics.
55+
bool isArrayLiteralIntrinsic(FullApplySite applySite);
5656

5757
/// If the given value `v` corresponds to an `ApplyInst` with
5858
/// `array.uninitialized_intrinsic` semantics, returns the corresponding
@@ -76,11 +76,22 @@ ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
7676
/// tuple-typed and such a user exists.
7777
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
7878

79-
/// Given an `apply` instruction, apply the given callback to each of its
80-
/// direct results. If the `apply` instruction has a single `destructure_tuple`
81-
/// user, apply the callback to the results of the `destructure_tuple` user.
79+
/// Given a full apply site, apply the given callback to each of its
80+
/// "direct results".
81+
///
82+
/// - `apply`
83+
/// Special case because `apply` returns a single (possibly tuple-typed) result
84+
/// instead of multiple results. If the `apply` has a single
85+
/// `destructure_tuple` user, treat the `destructure_tuple` results as the
86+
/// `apply` direct results.
87+
///
88+
/// - `begin_apply`
89+
/// Apply callback to each `begin_apply` direct result.
90+
///
91+
/// - `try_apply`
92+
/// Apply callback to each `try_apply` successor basic block argument.
8293
void forEachApplyDirectResult(
83-
ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback);
94+
FullApplySite applySite, llvm::function_ref<void(SILValue)> resultCallback);
8495

8596
/// Given a function, gathers all of its formal results (both direct and
8697
/// indirect) in an order defined by its result type. Note that "formal results"

include/swift/SILOptimizer/Utils/Differentiation/LinearMapInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class LinearMapInfo {
141141
SILFunction *derivative);
142142

143143
public:
144-
bool shouldDifferentiateApplyInst(ApplyInst *ai);
144+
bool shouldDifferentiateApplySite(FullApplySite applySite);
145145
bool shouldDifferentiateInstruction(SILInstruction *inst);
146146

147147
LinearMapInfo(const LinearMapInfo &) = delete;

include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
338338

339339
void visitApplyInst(ApplyInst *ai);
340340

341+
void visitBeginApplyInst(BeginApplyInst *bai);
342+
341343
/// Handle `struct` instruction.
342344
/// Original: y = struct (x0, x1, x2, ...)
343345
/// Adjoint: adj[x0] += struct_extract adj[y], #x0

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,21 @@ void DifferentiableActivityInfo::propagateVaried(
126126
// General rule: mark results as varied and recursively propagate variedness
127127
// to users of results.
128128
auto i = independentVariableIndex;
129-
// Handle `apply`.
130-
if (auto *ai = dyn_cast<ApplyInst>(inst)) {
129+
// Handle full apply sites: `apply`, `try_apply`, and `begin_apply`.
130+
if (FullApplySite::isa(inst)) {
131+
FullApplySite applySite(inst);
131132
// If callee is non-varying, skip.
132-
if (isWithoutDerivative(ai->getCallee()))
133+
if (isWithoutDerivative(applySite.getCallee()))
133134
return;
134135
// If operand is varied, set all direct/indirect results and inout arguments
135136
// as varied.
136137
if (isVaried(operand->get(), i)) {
137-
for (auto indRes : ai->getIndirectSILResults())
138+
for (auto indRes : applySite.getIndirectSILResults())
138139
propagateVariedInwardsThroughProjections(indRes, i);
139-
for (auto inoutArg : ai->getInoutArguments())
140+
for (auto inoutArg : applySite.getInoutArguments())
140141
propagateVariedInwardsThroughProjections(inoutArg, i);
141-
forEachApplyDirectResult(ai, [&](SILValue directResult) {
142+
// Propagate variedness to apply site direct results.
143+
forEachApplyDirectResult(applySite, [&](SILValue directResult) {
142144
setVariedAndPropagateToUsers(directResult, i);
143145
});
144146
}
@@ -218,7 +220,7 @@ void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
218220
// Set value as varied and propagate to users.
219221
setVariedAndPropagateToUsers(value, independentVariableIndex);
220222
auto *inst = value->getDefiningInstruction();
221-
if (!inst || isa<ApplyInst>(inst))
223+
if (!inst || ApplySite::isa(inst))
222224
return;
223225
// Standard propagation.
224226
for (auto &op : inst->getAllOperands())
@@ -262,11 +264,13 @@ void DifferentiableActivityInfo::propagateUseful(
262264
// Propagate usefulness for the given instruction: mark operands as useful and
263265
// recursively propagate usefulness to defining instructions of operands.
264266
auto i = dependentVariableIndex;
265-
// Handle indirect results in `apply`.
266-
if (auto *ai = dyn_cast<ApplyInst>(inst)) {
267-
if (isWithoutDerivative(ai->getCallee()))
267+
// Handle full apply sites: `apply`, `try_apply`, and `begin_apply`.
268+
if (FullApplySite::isa(inst)) {
269+
FullApplySite applySite(inst);
270+
// If callee is non-varying, skip.
271+
if (isWithoutDerivative(applySite.getCallee()))
268272
return;
269-
for (auto arg : ai->getArgumentsWithoutIndirectResults())
273+
for (auto arg : applySite.getArgumentsWithoutIndirectResults())
270274
setUsefulAndPropagateToOperands(arg, i);
271275
}
272276
// Handle store-like instructions:

lib/SILOptimizer/Utils/Differentiation/Common.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ namespace autodiff {
2626

2727
raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; }
2828

29-
bool isArrayLiteralIntrinsic(ApplyInst *ai) {
30-
return ai->hasSemantics("array.uninitialized_intrinsic");
29+
bool isArrayLiteralIntrinsic(FullApplySite applySite) {
30+
return doesApplyCalleeHaveSemantics(applySite.getCalleeOrigin(),
31+
"array.uninitialized_intrinsic");
3132
}
3233

3334
ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
@@ -71,14 +72,34 @@ DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
7172
}
7273

7374
void forEachApplyDirectResult(
74-
ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback) {
75-
if (!ai->getType().is<TupleType>()) {
76-
resultCallback(ai);
77-
return;
75+
FullApplySite applySite,
76+
llvm::function_ref<void(SILValue)> resultCallback) {
77+
switch (applySite.getKind()) {
78+
case FullApplySiteKind::ApplyInst: {
79+
auto *ai = cast<ApplyInst>(applySite.getInstruction());
80+
if (!ai->getType().is<TupleType>()) {
81+
resultCallback(ai);
82+
return;
83+
}
84+
if (auto *dti = getSingleDestructureTupleUser(ai))
85+
for (auto directResult : dti->getResults())
86+
resultCallback(directResult);
87+
break;
88+
}
89+
case FullApplySiteKind::BeginApplyInst: {
90+
auto *bai = cast<BeginApplyInst>(applySite.getInstruction());
91+
for (auto directResult : bai->getResults())
92+
resultCallback(directResult);
93+
break;
94+
}
95+
case FullApplySiteKind::TryApplyInst: {
96+
auto *tai = cast<TryApplyInst>(applySite.getInstruction());
97+
for (auto *succBB : tai->getSuccessorBlocks())
98+
for (auto *arg : succBB->getArguments())
99+
resultCallback(arg);
100+
break;
101+
}
78102
}
79-
if (auto *dti = getSingleDestructureTupleUser(ai))
80-
for (auto result : dti->getResults())
81-
resultCallback(result);
82103
}
83104

84105
void collectAllFormalResultsInTypeOrder(SILFunction &function,

lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) {
746746
void JVPEmitter::emitTangentForApplyInst(
747747
ApplyInst *ai, SILAutoDiffIndices actualIndices,
748748
CanSILFunctionType originalDifferentialType) {
749-
assert(differentialInfo.shouldDifferentiateApplyInst(ai));
749+
assert(differentialInfo.shouldDifferentiateApplySite(ai));
750750
auto *bb = ai->getParent();
751751
auto loc = ai->getLoc();
752752
auto &diffBuilder = getDifferentialBuilder();
@@ -1184,7 +1184,7 @@ void JVPEmitter::visitInstructionsInBlock(SILBasicBlock *bb) {
11841184
void JVPEmitter::visitApplyInst(ApplyInst *ai) {
11851185
// If the function should not be differentiated or its the array literal
11861186
// initialization intrinsic, just do standard cloning.
1187-
if (!differentialInfo.shouldDifferentiateApplyInst(ai) ||
1187+
if (!differentialInfo.shouldDifferentiateApplySite(ai) ||
11881188
isArrayLiteralIntrinsic(ai)) {
11891189
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
11901190
TypeSubstCloner::visitApplyInst(ai);

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
411411
// Add linear map field to struct for active `apply` instructions.
412412
// Skip array literal intrinsic applications since array literal
413413
// initialization is linear and handled separately.
414-
if (!shouldDifferentiateApplyInst(ai) || isArrayLiteralIntrinsic(ai))
414+
if (!shouldDifferentiateApplySite(ai) || isArrayLiteralIntrinsic(ai))
415415
continue;
416416

417417
LLVM_DEBUG(getADDebugStream() << "Adding linear map struct field for "
@@ -454,26 +454,29 @@ void LinearMapInfo::generateDifferentiationDataStructures(
454454
/// there is a `store` of an active value into the array's buffer.
455455
/// 3. The instruction has both an active result (direct or indirect) and an
456456
/// active argument.
457-
bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
457+
bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
458458
// Function applications with an inout argument should be differentiated.
459-
for (auto inoutArg : ai->getInoutArguments())
459+
for (auto inoutArg : applySite.getInoutArguments())
460460
if (activityInfo.isActive(inoutArg, indices))
461461
return true;
462462

463463
bool hasActiveDirectResults = false;
464-
forEachApplyDirectResult(ai, [&](SILValue directResult) {
464+
forEachApplyDirectResult(applySite, [&](SILValue directResult) {
465465
hasActiveDirectResults |= activityInfo.isActive(directResult, indices);
466466
});
467-
bool hasActiveIndirectResults = llvm::any_of(ai->getIndirectSILResults(),
468-
[&](SILValue result) { return activityInfo.isActive(result, indices); });
467+
bool hasActiveIndirectResults =
468+
llvm::any_of(applySite.getIndirectSILResults(), [&](SILValue result) {
469+
return activityInfo.isActive(result, indices);
470+
});
469471
bool hasActiveResults = hasActiveDirectResults || hasActiveIndirectResults;
470472

471473
// TODO: Pattern match to make sure there is at least one `store` to the
472474
// array's active buffer.
473-
if (isArrayLiteralIntrinsic(ai) && hasActiveResults)
475+
// if (isArrayLiteralIntrinsic(ai) && hasActiveResults)
476+
if (isArrayLiteralIntrinsic(applySite) && hasActiveResults)
474477
return true;
475478

476-
auto arguments = ai->getArgumentsWithoutIndirectResults();
479+
auto arguments = applySite.getArgumentsWithoutIndirectResults();
477480
bool hasActiveArguments = llvm::any_of(arguments,
478481
[&](SILValue arg) { return activityInfo.isActive(arg, indices); });
479482
return hasActiveResults && hasActiveArguments;
@@ -483,19 +486,19 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
483486
/// given the differentiation indices of the instruction's parent function.
484487
/// Whether the instruction should be differentiated is determined sequentially
485488
/// from any of the following conditions:
486-
/// 1. The instruction is an `apply` and `shouldDifferentiateApplyInst` returns
487-
/// true.
489+
/// 1. The instruction is a full apply site and `shouldDifferentiateApplyInst`
490+
/// returns true.
488491
/// 2. The instruction has a source operand and a destination operand, both
489492
/// being active.
490493
/// 3. The instruction is an allocation instruction and has an active result.
491494
/// 4. The instruction performs reference counting, lifetime ending, access
492495
/// ending, or destroying on an active operand.
493496
/// 5. The instruction creates an SSA copy of an active operand.
494497
bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
495-
// An `apply` with an active argument and an active result (direct or
498+
// A full apply site with an active argument and an active result (direct or
496499
// indirect) should be differentiated.
497-
if (auto *ai = dyn_cast<ApplyInst>(inst))
498-
return shouldDifferentiateApplyInst(ai);
500+
if (FullApplySite::isa(inst))
501+
return shouldDifferentiateApplySite(FullApplySite(inst));
499502
// Anything with an active result and an active operand should be
500503
// differentiated.
501504
auto hasActiveOperands = llvm::any_of(inst->getAllOperands(),

lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
10781078
}
10791079

10801080
void PullbackEmitter::visitApplyInst(ApplyInst *ai) {
1081-
assert(getPullbackInfo().shouldDifferentiateApplyInst(ai));
1081+
assert(getPullbackInfo().shouldDifferentiateApplySite(ai));
10821082
// Skip `array.uninitialized_intrinsic` intrinsic applications, which have
10831083
// special `store` and `copy_addr` support.
10841084
if (isArrayLiteralIntrinsic(ai))
@@ -1273,6 +1273,15 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
12731273
}
12741274
}
12751275

1276+
void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) {
1277+
// Diagnose `begin_apply` instructions.
1278+
// Coroutine differentiation is not yet supported.
1279+
getContext().emitNondifferentiabilityError(
1280+
bai, getInvoker(), diag::autodiff_coroutines_not_supported);
1281+
errorOccurred = true;
1282+
return;
1283+
}
1284+
12761285
void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
12771286
assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
12781287
"`struct_extract` with `@noDerivative` field should not be "

lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ void VJPEmitter::visitSwitchEnumInst(SwitchEnumInst *sei) {
430430
void VJPEmitter::visitApplyInst(ApplyInst *ai) {
431431
// If the function should not be differentiated or its the array literal
432432
// initialization intrinsic, just do standard cloning.
433-
if (!pullbackInfo.shouldDifferentiateApplyInst(ai) ||
433+
if (!pullbackInfo.shouldDifferentiateApplySite(ai) ||
434434
isArrayLiteralIntrinsic(ai)) {
435435
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
436436
TypeSubstCloner::visitApplyInst(ai);

0 commit comments

Comments
 (0)