[AutoDiff] Array literal differentiation fixes. #28889

Merged
merged 1 commit into from
Dec 20, 2019
33 changes: 14 additions & 19 deletions include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h
@@ -274,6 +274,20 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
SILValue rhsBufferAccess, SILLocation loc);

/// Given the adjoint value of an array initialized from an
/// `array.uninitialized_intrinsic` application and an array element index,
/// returns an `alloc_stack` containing the adjoint value of the array element
/// at the given index by applying `Array.TangentVector.subscript`.
AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
int eltIndex, SILLocation loc);

/// Given the adjoint value of an array initialized from an
/// `array.uninitialized_intrinsic` application, accumulate the adjoint
/// value's elements into the adjoint buffers of its element addresses.
void accumulateArrayLiteralElementAddressAdjoints(
SILBasicBlock *origBB, SILValue originalValue,
AdjointValue arrayAdjointValue, SILLocation loc);

//--------------------------------------------------------------------------//
// CFG mapping
//--------------------------------------------------------------------------//
@@ -322,25 +336,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {

void visitSILInstruction(SILInstruction *inst);

/// Given an array adjoint value, array element index and element tangent
/// type, returns an `alloc_stack` containing the array element adjoint value.
AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
int eltIndex, SILType eltTanType,
SILLocation loc);

/// Accumulate array element adjoint buffer into `store` source.
void accumulateArrayElementAdjointDirect(StoreInst *si,
AllocStackInst *eltAdjBuffer);

/// Accumulate array element adjoint buffer into `copy_addr` source.
void accumulateArrayElementAdjointIndirect(CopyAddrInst *cai,
AllocStackInst *eltAdjBuffer);

/// Given a `store` or `copy_addr` instruction whose destination is an element
/// address from an `array.uninitialized_intrinsic` application, accumulate
/// array element adjoint into the source's adjoint.
void accumulateArrayElementAdjoint(SILInstruction *inst);

void visitApplyInst(ApplyInst *ai);

/// Handle `struct` instruction.
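
Note on the header change above: the doc comment for `accumulateArrayLiteralElementAddressAdjoints` describes reading each element of the array adjoint and accumulating it into the adjoint buffer of the matching element address. The following is only a toy sketch of that idea, assuming scalar `float` tangents and a hypothetical helper name `accumulateElementAdjoints`; the actual pass emits SIL (including an application of `Array.TangentVector.subscript`) rather than operating on vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model, not the PullbackEmitter API: the array adjoint is a vector of
// element tangents, and each element address owns one adjoint buffer.
// `accumulateElementAdjoints` is a hypothetical name used for illustration.
void accumulateElementAdjoints(const std::vector<float> &arrayAdjoint,
                               std::vector<float> &elementAdjointBuffers) {
  assert(arrayAdjoint.size() == elementAdjointBuffers.size());
  for (std::size_t i = 0; i < arrayAdjoint.size(); ++i) {
    // Read the adjoint of element i (the role played by the subscript step).
    float eltAdjoint = arrayAdjoint[i];
    // Accumulate (add) it into the adjoint buffer of element address i.
    elementAdjointBuffers[i] += eltAdjoint;
  }
}
```

The buffers are added to rather than overwritten, so any adjoint contribution already stored for an element address is preserved.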
33 changes: 25 additions & 8 deletions lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp
@@ -9,6 +9,7 @@
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "differentiation"

#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
@@ -343,18 +344,25 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser());
assert(ptai && "Expected `pointer_to_address` user for uninitialized "
"array intrinsic");
// Propagate usefulness through array element addresses.
// - Find `store` and `copy_addr` instructions with array element
// address destinations.
// - For each instruction, set destination (array element address) as
// useful and propagate usefulness through source.
setUseful(ptai, dependentVariableIndex);
// Propagate usefulness through array element addresses:
// `pointer_to_address` and `index_addr` instructions.
//
// - Set all array element addresses as useful.
// - Find instructions with array element addresses as "result":
// - `store` and `copy_addr` with array element address as destination.
// - `apply` with array element address as an indirect result.
// - For each instruction, propagate usefulness through "arguments":
// - `store` and `copy_addr`: propagate to source.
// - `apply`: propagate to arguments.
//
// Note: `propagateUseful(use->getUser(), ...)` is intentionally not used
// NOTE: `propagateUseful(use->getUser(), ...)` is intentionally not used
[Review comment from the PR author]

Note: rather than pattern-matching yet another special case of "writing to an array literal element address" (store, copy_addr, and now apply indirect result) here in DifferentiableActivityInfo::setUsefulThroughArrayInitialization, it may be nicer to make a blanket call to propagateUseful(use->getUser(), ...), at the risk of marking some "junk" values as useful:

  • The array.uninitialized_intrinsic RawPointer result
  • integer_literal operands to index_addr element addresses.

In practice, calling propagateUseful(use->getUser(), ...) is blocked by TF-1032, which I plan to investigate soon.

// because it marks more values than necessary as useful, including:
// - The `RawPointer` result of the intrinsic.
// - The `pointer_to_address` user of the `RawPointer`.
// - `index_addr` and `integer_literal` instructions for indexing the
// - `integer_literal` operands to `index_addr` for indexing the
// `RawPointer`.
// It is also blocked by TF-1032: control flow differentiation crash for
// active values with no tangent space.
for (auto use : ptai->getUses()) {
auto *user = use->getUser();
if (auto *si = dyn_cast<StoreInst>(user)) {
@@ -364,7 +372,12 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
setUseful(cai->getDest(), dependentVariableIndex);
setUsefulAndPropagateToOperands(cai->getSrc(),
dependentVariableIndex);
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
if (FullApplySite(ai).isIndirectResultOperand(*use))
for (auto arg : ai->getArgumentsWithoutIndirectResults())
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
} else if (auto *iai = dyn_cast<IndexAddrInst>(user)) {
setUseful(iai, dependentVariableIndex);
for (auto use : iai->getUses()) {
auto *user = use->getUser();
if (auto si = dyn_cast<StoreInst>(user)) {
@@ -375,6 +388,10 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
setUseful(cai->getDest(), dependentVariableIndex);
setUsefulAndPropagateToOperands(cai->getSrc(),
dependentVariableIndex);
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
if (FullApplySite(ai).isIndirectResultOperand(*use))
for (auto arg : ai->getArgumentsWithoutIndirectResults())
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
}
}
}
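
Note on the activity analysis change above: the updated comment lists three kinds of instruction that write to an array element address (`store`, `copy_addr`, and `apply` with the address as an indirect result) and propagates usefulness to their inputs. The sketch below models only that selection rule on a toy representation. `WriteKind`, `User`, and `markUseful` are hypothetical names, values are plain strings, and propagation stops after one level, whereas the real `setUsefulAndPropagateToOperands` continues through operands.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy stand-ins for SIL instructions; these are not Swift compiler types.
enum class WriteKind { Store, CopyAddr, Apply };

struct User {
  WriteKind kind;
  std::string elementAddress;       // array element address operand
  std::vector<std::string> inputs;  // store/copy_addr source, or apply arguments
  bool addressIsIndirectResult;     // only meaningful for Apply
};

// Returns every value marked useful: all element addresses, plus the inputs
// of each instruction that writes to one of those addresses.
std::set<std::string> markUseful(const std::vector<std::string> &elementAddresses,
                                 const std::vector<User> &users) {
  const std::set<std::string> addresses(elementAddresses.begin(),
                                        elementAddresses.end());
  std::set<std::string> useful = addresses;
  for (const User &user : users) {
    if (!addresses.count(user.elementAddress))
      continue;
    // An apply only writes to the address when it is an indirect result.
    if (user.kind == WriteKind::Apply && !user.addressIsIndirectResult)
      continue;
    useful.insert(user.inputs.begin(), user.inputs.end());
  }
  return useful;
}
```

In this model an `apply` that merely takes an element address as an ordinary argument contributes nothing, which mirrors the `isIndirectResultOperand` check in the diff.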