
[AutoDiff] Diagnose unsupported coroutine differentiation. #28921

Merged: 2 commits, Dec 21, 2019
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
@@ -541,6 +541,9 @@ NOTE(autodiff_control_flow_not_supported,none,
// TODO(TF-645): Remove when differentiation supports `ref_element_addr`.
NOTE(autodiff_class_property_not_supported,none,
"differentiating class properties is not yet supported", ())
// TODO(TF-1080): Remove when differentiation supports `begin_apply`.
NOTE(autodiff_coroutines_not_supported,none,
"differentiation of coroutine calls is not yet supported", ())
NOTE(autodiff_missing_return,none,
"missing return for differentiation", ())

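For context, here is a minimal reproducer of the kind of code the new note targets. This is hypothetical, not taken from this PR's tests, and assumes a tensorflow-branch toolchain where `Array` conditionally conforms to `Differentiable`:

```swift
@differentiable
func squareFirst(_ values: [Float]) -> Float {
    var values = values
    // Compound assignment through `Array.subscript` goes through the
    // subscript's `modify` accessor, a yield-once coroutine, so this
    // line lowers to `begin_apply` in SIL.
    values[0] *= values[0]
    return values[0]
}
```

Differentiating `squareFirst` now surfaces the new note under the standard "function is not differentiable" error instead of failing without an actionable diagnostic.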
13 changes: 13 additions & 0 deletions include/swift/SIL/ApplySite.h
@@ -503,6 +503,19 @@ class FullApplySite : public ApplySite {
return getArguments().slice(getNumIndirectSILResults());
}

// SWIFT_ENABLE_TENSORFLOW
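/// Returns all `@inout` and `@inout_aliasable` arguments passed to this
/// full apply site.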
InoutArgumentRange getInoutArguments() const {
switch (getKind()) {
case FullApplySiteKind::ApplyInst:
return cast<ApplyInst>(getInstruction())->getInoutArguments();
case FullApplySiteKind::TryApplyInst:
return cast<TryApplyInst>(getInstruction())->getInoutArguments();
case FullApplySiteKind::BeginApplyInst:
return cast<BeginApplyInst>(getInstruction())->getInoutArguments();
}
llvm_unreachable("invalid full apply site kind");
}
// SWIFT_ENABLE_TENSORFLOW END

/// Returns true if \p op is the callee operand of this apply site
/// and not an argument operand.
bool isCalleeOperand(const Operand &op) const {
48 changes: 26 additions & 22 deletions include/swift/SIL/SILInstruction.h
@@ -2041,6 +2041,27 @@ class ApplyInstBase<Impl, Base, false> : public Base {
/// does it have the given semantics?
bool doesApplyCalleeHaveSemantics(SILValue callee, StringRef semantics);

// SWIFT_ENABLE_TENSORFLOW
/// Predicate used to filter InoutArgumentRange.
struct OperandToInoutArgument {
ArrayRef<SILParameterInfo> paramInfos;
OperandValueArrayRef arguments;
OperandToInoutArgument(ArrayRef<SILParameterInfo> paramInfos,
OperandValueArrayRef arguments)
: paramInfos(paramInfos), arguments(arguments) {
assert(paramInfos.size() == arguments.size());
}
Optional<SILValue> operator()(unsigned long i) const {
if (paramInfos[i].isIndirectMutating())
return arguments[i];
return None;
}
};

using InoutArgumentRange =
OptionalTransformRange<IntRange<unsigned long>, OperandToInoutArgument>;
// SWIFT_ENABLE_TENSORFLOW END

/// The partial specialization of ApplyInstBase for full applications.
/// Adds some methods relating to 'self' and to result types that don't
/// make sense for partial applications.
@@ -2147,31 +2168,14 @@ class ApplyInstBase<Impl, Base, true>
}

// SWIFT_ENABLE_TENSORFLOW
private:
/// Predicate used to filter InoutArgumentRange.
struct OperandToInoutArgument {
ArrayRef<SILParameterInfo> paramInfos;
OperandValueArrayRef arguments;
OperandToInoutArgument(const Impl &inst)
: paramInfos(inst.getSubstCalleeConv().getParameters()),
arguments(inst.getArgumentsWithoutIndirectResults()) {
assert(paramInfos.size() == arguments.size());
}
Optional<SILValue> operator()(unsigned long i) const {
if (paramInfos[i].isIndirectMutating())
return arguments[i];
return None;
}
};

public:
using InoutArgumentRange =
OptionalTransformRange<IntRange<unsigned long>, OperandToInoutArgument>;
/// Returns all `@inout` and `@inout_aliasable` arguments passed to the
/// instruction.
InoutArgumentRange getInoutArguments() const {
return InoutArgumentRange(indices(getArgumentsWithoutIndirectResults()),
OperandToInoutArgument(asImpl()));
auto &impl = asImpl();
return InoutArgumentRange(
indices(getArgumentsWithoutIndirectResults()),
OperandToInoutArgument(impl.getSubstCalleeConv().getParameters(),
impl.getArgumentsWithoutIndirectResults()));
}
// SWIFT_ENABLE_TENSORFLOW END
};
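For intuition, the `OperandToInoutArgument`/`OptionalTransformRange` pairing above is a lazy filter-map: the functor maps each argument index to an `Optional<SILValue>`, and the range yields only the populated values. A loose Swift analogue, with illustrative names only:

```swift
// Stand-ins for the (parameter info, argument) pairs at an apply site.
let params: [(isIndirectMutating: Bool, argument: String)] = [
    (true, "%x"), (false, "%y"), (true, "%z"),
]

// Like OperandToInoutArgument: map an index to an optional value.
// Like OptionalTransformRange: lazily keep only the non-nil results.
let inoutArguments = params.indices.lazy.compactMap { i in
    params[i].isIndirectMutating ? params[i].argument : nil
}

print(Array(inoutArguments)) // ["%x", "%z"]
```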
25 changes: 18 additions & 7 deletions include/swift/SILOptimizer/Utils/Differentiation/Common.h
@@ -50,9 +50,9 @@ namespace autodiff {
/// This is used to print short debug messages within the AD pass.
raw_ostream &getADDebugStream();

/// Returns true if this is an `ApplyInst` with `array.uninitialized_intrinsic`
/// semantics.
bool isArrayLiteralIntrinsic(ApplyInst *ai);
/// Returns true if this is a full apply site whose callee has
/// `array.uninitialized_intrinsic` semantics.
bool isArrayLiteralIntrinsic(FullApplySite applySite);

/// If the given value `v` corresponds to an `ApplyInst` with
/// `array.uninitialized_intrinsic` semantics, returns the corresponding
@@ -76,11 +76,22 @@ ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
/// tuple-typed and such a user exists.
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);

/// Given an `apply` instruction, apply the given callback to each of its
/// direct results. If the `apply` instruction has a single `destructure_tuple`
/// user, apply the callback to the results of the `destructure_tuple` user.
/// Given a full apply site, apply the given callback to each of its
/// "direct results".
///
/// - `apply`
/// Special case because `apply` returns a single (possibly tuple-typed) result
/// instead of multiple results. If the `apply` has a single
/// `destructure_tuple` user, treat the `destructure_tuple` results as the
/// `apply` direct results.
///
/// - `begin_apply`
/// Apply callback to each `begin_apply` direct result.
///
/// - `try_apply`
/// Apply callback to each `try_apply` successor basic block argument.
void forEachApplyDirectResult(
ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback);
FullApplySite applySite, llvm::function_ref<void(SILValue)> resultCallback);

/// Given a function, gathers all of its formal results (both direct and
/// indirect) in an order defined by its result type. Note that "formal results"
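To make the three cases concrete, here is hypothetical Swift whose SIL exhibits each full apply site kind (the underscored `_read` accessor syntax is unofficial and shown only for illustration):

```swift
// `apply` + `destructure_tuple`: a tuple-returning call produces one
// tuple-typed direct result, typically destructured immediately.
func minMax(_ x: Float, _ y: Float) -> (min: Float, max: Float) {
    (min(x, y), max(x, y))
}

// `try_apply`: the call's result arrives as an argument of the normal
// successor basic block rather than as an instruction result.
func checked(_ x: Float) throws -> Float { x }

// `begin_apply`: `read`/`modify` accessors are yield-once coroutines;
// their yielded values are the direct results.
struct Box {
    var value: Float
    var doubled: Float {
        _read { yield value * 2 }
    }
}
```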
include/swift/SILOptimizer/Utils/Differentiation/LinearMapInfo.h
@@ -141,7 +141,7 @@ class LinearMapInfo {
SILFunction *derivative);

public:
bool shouldDifferentiateApplyInst(ApplyInst *ai);
bool shouldDifferentiateApplySite(FullApplySite applySite);
bool shouldDifferentiateInstruction(SILInstruction *inst);

LinearMapInfo(const LinearMapInfo &) = delete;
include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h
@@ -338,6 +338,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {

void visitApplyInst(ApplyInst *ai);

void visitBeginApplyInst(BeginApplyInst *bai);

/// Handle `struct` instruction.
/// Original: y = struct (x0, x1, x2, ...)
/// Adjoint: adj[x0] += struct_extract adj[y], #x0
26 changes: 15 additions & 11 deletions lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp
@@ -126,19 +126,21 @@ void DifferentiableActivityInfo::propagateVaried(
// General rule: mark results as varied and recursively propagate variedness
// to users of results.
auto i = independentVariableIndex;
// Handle `apply`.
if (auto *ai = dyn_cast<ApplyInst>(inst)) {
// Handle full apply sites: `apply`, `try_apply`, and `begin_apply`.
if (FullApplySite::isa(inst)) {
FullApplySite applySite(inst);
// If callee is non-varying, skip.
if (isWithoutDerivative(ai->getCallee()))
if (isWithoutDerivative(applySite.getCallee()))
return;
// If operand is varied, set all direct/indirect results and inout arguments
// as varied.
if (isVaried(operand->get(), i)) {
for (auto indRes : ai->getIndirectSILResults())
for (auto indRes : applySite.getIndirectSILResults())
propagateVariedInwardsThroughProjections(indRes, i);
for (auto inoutArg : ai->getInoutArguments())
for (auto inoutArg : applySite.getInoutArguments())
propagateVariedInwardsThroughProjections(inoutArg, i);
forEachApplyDirectResult(ai, [&](SILValue directResult) {
// Propagate variedness to apply site direct results.
forEachApplyDirectResult(applySite, [&](SILValue directResult) {
setVariedAndPropagateToUsers(directResult, i);
});
}
@@ -218,7 +220,7 @@ void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
// Set value as varied and propagate to users.
setVariedAndPropagateToUsers(value, independentVariableIndex);
auto *inst = value->getDefiningInstruction();
if (!inst || isa<ApplyInst>(inst))
if (!inst || ApplySite::isa(inst))
return;
// Standard propagation.
for (auto &op : inst->getAllOperands())
@@ -262,11 +264,13 @@ void DifferentiableActivityInfo::propagateUseful(
// Propagate usefulness for the given instruction: mark operands as useful and
// recursively propagate usefulness to defining instructions of operands.
auto i = dependentVariableIndex;
// Handle indirect results in `apply`.
if (auto *ai = dyn_cast<ApplyInst>(inst)) {
if (isWithoutDerivative(ai->getCallee()))
// Handle full apply sites: `apply`, `try_apply`, and `begin_apply`.
if (FullApplySite::isa(inst)) {
FullApplySite applySite(inst);
// If callee is non-varying, skip.
if (isWithoutDerivative(applySite.getCallee()))
return;
for (auto arg : ai->getArgumentsWithoutIndirectResults())
for (auto arg : applySite.getArgumentsWithoutIndirectResults())
setUsefulAndPropagateToOperands(arg, i);
}
// Handle store-like instructions:
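As a sanity check on the new propagation rule, consider a hypothetical example (not from this PR): when a callee mutates an `inout` argument, variedness must flow from the varied operand into that argument's buffer.

```swift
func scale(_ x: inout Float, by factor: Float) {
    x *= factor
}

func f(_ x: Float) -> Float {
    var y = x        // `y` is varied with respect to `x`.
    scale(&y, by: 3) // An `apply` with an `@inout` argument: an operand
                     // (`y`, seeded from `x`) is varied, so the inout
                     // buffer is marked varied after the call as well.
    return y
}
```

The same rule now applies uniformly when the call site is a `try_apply` or `begin_apply`.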
39 changes: 30 additions & 9 deletions lib/SILOptimizer/Utils/Differentiation/Common.cpp
@@ -26,8 +26,9 @@ namespace autodiff {

raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; }

bool isArrayLiteralIntrinsic(ApplyInst *ai) {
return ai->hasSemantics("array.uninitialized_intrinsic");
bool isArrayLiteralIntrinsic(FullApplySite applySite) {
return doesApplyCalleeHaveSemantics(applySite.getCalleeOrigin(),
"array.uninitialized_intrinsic");
}

ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
@@ -71,14 +72,34 @@ DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
}

void forEachApplyDirectResult(
ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback) {
if (!ai->getType().is<TupleType>()) {
resultCallback(ai);
return;
FullApplySite applySite,
llvm::function_ref<void(SILValue)> resultCallback) {
switch (applySite.getKind()) {
case FullApplySiteKind::ApplyInst: {
auto *ai = cast<ApplyInst>(applySite.getInstruction());
if (!ai->getType().is<TupleType>()) {
resultCallback(ai);
return;
}
if (auto *dti = getSingleDestructureTupleUser(ai))
for (auto directResult : dti->getResults())
resultCallback(directResult);
break;
}
case FullApplySiteKind::BeginApplyInst: {
auto *bai = cast<BeginApplyInst>(applySite.getInstruction());
for (auto directResult : bai->getResults())
resultCallback(directResult);
break;
}
case FullApplySiteKind::TryApplyInst: {
auto *tai = cast<TryApplyInst>(applySite.getInstruction());
for (auto *succBB : tai->getSuccessorBlocks())
for (auto *arg : succBB->getArguments())
resultCallback(arg);
break;
}
}
if (auto *dti = getSingleDestructureTupleUser(ai))
for (auto result : dti->getResults())
resultCallback(result);
}

void collectAllFormalResultsInTypeOrder(SILFunction &function,
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp
@@ -746,7 +746,7 @@ CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) {
void JVPEmitter::emitTangentForApplyInst(
ApplyInst *ai, SILAutoDiffIndices actualIndices,
CanSILFunctionType originalDifferentialType) {
assert(differentialInfo.shouldDifferentiateApplyInst(ai));
assert(differentialInfo.shouldDifferentiateApplySite(ai));
auto *bb = ai->getParent();
auto loc = ai->getLoc();
auto &diffBuilder = getDifferentialBuilder();
@@ -1184,7 +1184,7 @@ void JVPEmitter::visitInstructionsInBlock(SILBasicBlock *bb) {
void JVPEmitter::visitApplyInst(ApplyInst *ai) {
// If the function should not be differentiated or it's the array literal
// initialization intrinsic, just do standard cloning.
if (!differentialInfo.shouldDifferentiateApplyInst(ai) ||
if (!differentialInfo.shouldDifferentiateApplySite(ai) ||
isArrayLiteralIntrinsic(ai)) {
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
TypeSubstCloner::visitApplyInst(ai);
29 changes: 16 additions & 13 deletions lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp
@@ -411,7 +411,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
// Add linear map field to struct for active `apply` instructions.
// Skip array literal intrinsic applications since array literal
// initialization is linear and handled separately.
if (!shouldDifferentiateApplyInst(ai) || isArrayLiteralIntrinsic(ai))
if (!shouldDifferentiateApplySite(ai) || isArrayLiteralIntrinsic(ai))
continue;

LLVM_DEBUG(getADDebugStream() << "Adding linear map struct field for "
@@ -454,26 +454,29 @@ void LinearMapInfo::generateDifferentiationDataStructures(
/// there is a `store` of an active value into the array's buffer.
/// 3. The instruction has both an active result (direct or indirect) and an
/// active argument.
bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
// Function applications with an inout argument should be differentiated.
for (auto inoutArg : ai->getInoutArguments())
for (auto inoutArg : applySite.getInoutArguments())
if (activityInfo.isActive(inoutArg, indices))
return true;

bool hasActiveDirectResults = false;
forEachApplyDirectResult(ai, [&](SILValue directResult) {
forEachApplyDirectResult(applySite, [&](SILValue directResult) {
hasActiveDirectResults |= activityInfo.isActive(directResult, indices);
});
bool hasActiveIndirectResults = llvm::any_of(ai->getIndirectSILResults(),
[&](SILValue result) { return activityInfo.isActive(result, indices); });
bool hasActiveIndirectResults =
llvm::any_of(applySite.getIndirectSILResults(), [&](SILValue result) {
return activityInfo.isActive(result, indices);
});
bool hasActiveResults = hasActiveDirectResults || hasActiveIndirectResults;

// TODO: Pattern match to make sure there is at least one `store` to the
// array's active buffer.
if (isArrayLiteralIntrinsic(ai) && hasActiveResults)
if (isArrayLiteralIntrinsic(applySite) && hasActiveResults)
return true;

auto arguments = ai->getArgumentsWithoutIndirectResults();
auto arguments = applySite.getArgumentsWithoutIndirectResults();
bool hasActiveArguments = llvm::any_of(arguments,
[&](SILValue arg) { return activityInfo.isActive(arg, indices); });
return hasActiveResults && hasActiveArguments;
@@ -483,19 +486,19 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
/// given the differentiation indices of the instruction's parent function.
/// Whether the instruction should be differentiated is determined sequentially
/// from any of the following conditions:
/// 1. The instruction is an `apply` and `shouldDifferentiateApplyInst` returns
/// true.
/// 1. The instruction is a full apply site and `shouldDifferentiateApplySite`
/// returns true.
/// 2. The instruction has a source operand and a destination operand, both
/// being active.
/// 3. The instruction is an allocation instruction and has an active result.
/// 4. The instruction performs reference counting, lifetime ending, access
/// ending, or destroying on an active operand.
/// 5. The instruction creates an SSA copy of an active operand.
bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
// An `apply` with an active argument and an active result (direct or
// A full apply site with an active argument and an active result (direct or
// indirect) should be differentiated.
if (auto *ai = dyn_cast<ApplyInst>(inst))
return shouldDifferentiateApplyInst(ai);
if (FullApplySite::isa(inst))
return shouldDifferentiateApplySite(FullApplySite(inst));
// Anything with an active result and an active operand should be
// differentiated.
auto hasActiveOperands = llvm::any_of(inst->getAllOperands(),
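To make the activity conditions concrete, here is a hypothetical sketch of which calls inside a differentiated function get a linear map struct field ("active" means varied by an input and useful to the output):

```swift
@differentiable
func h(_ x: Float) -> Float {
    let a = x * x        // Active argument and active result: differentiated.
    let b = Float(1) + 2 // `b` is useful (it reaches the result) but not
                         // varied, so this call is skipped.
    let c = x + 1        // `c` is varied but never reaches the result:
    _ = c                // not useful, so this call is skipped too.
    return a + b         // Active argument `a`: differentiated.
}
```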
11 changes: 10 additions & 1 deletion lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp
@@ -1078,7 +1078,7 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
}

void PullbackEmitter::visitApplyInst(ApplyInst *ai) {
assert(getPullbackInfo().shouldDifferentiateApplyInst(ai));
assert(getPullbackInfo().shouldDifferentiateApplySite(ai));
// Skip `array.uninitialized_intrinsic` intrinsic applications, which have
// special `store` and `copy_addr` support.
if (isArrayLiteralIntrinsic(ai))
@@ -1273,6 +1273,15 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
}
}

void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) {
// Diagnose `begin_apply` instructions.
// Coroutine differentiation is not yet supported.
getContext().emitNondifferentiabilityError(
bai, getInvoker(), diag::autodiff_coroutines_not_supported);
errorOccurred = true;
return;
}

void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
"`struct_extract` with `@noDerivative` field should not be "
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp
@@ -430,7 +430,7 @@ void VJPEmitter::visitSwitchEnumInst(SwitchEnumInst *sei) {
void VJPEmitter::visitApplyInst(ApplyInst *ai) {
// If the function should not be differentiated or it's the array literal
// initialization intrinsic, just do standard cloning.
if (!pullbackInfo.shouldDifferentiateApplyInst(ai) ||
if (!pullbackInfo.shouldDifferentiateApplySite(ai) ||
isArrayLiteralIntrinsic(ai)) {
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
TypeSubstCloner::visitApplyInst(ai);