Skip to content

Commit 05c4539

Browse files
authored
[AutoDiff] Array literal differentiation fixes. (#28889)
Activity analysis change: - Mark array literal element addresses (`pointer_to_address` and `index_addr` instructions) as useful. These addresses legitimately need a derivative. - Propagate usefulness through array literal element addresses that are `apply` indirect results to the `apply` arguments. Pullback generation changes: - Add special case for array literal element addresses to `PullbackEmitter::getAdjointProjection`. The adjoint projection is a local allocation initialized from the array literal's adjoint value by applying `Array.TangentVector.subscript`. - This generalizes the old logic in `PullbackEmitter::visit{Store,CopyAddr}Inst` for handling array literal element addresses. - When an array literal's adjoint value is updated via `PullbackEmitter::addAdjointValue`, accumulate the array literal's adjoint value into the adjoint buffers of its element addresses. Resolves cleanup task: TF-976. Fixes correctness issues: - TF-975: nested array literals. - TF-978: array literal element address initialized as `apply` indirect result. Update activity analysis and derivative correctness tests.
1 parent 7d3ae09 commit 05c4539

File tree

5 files changed

+202
-174
lines changed

5 files changed

+202
-174
lines changed

include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,20 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
274274
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
275275
SILValue rhsBufferAccess, SILLocation loc);
276276

277+
/// Given the adjoint value of an array initialized from an
278+
/// `array.uninitialized_intrinsic` application and an array element index,
279+
/// returns an `alloc_stack` containing the adjoint value of the array element
280+
/// at the given index by applying `Array.TangentVector.subscript`.
281+
AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
282+
int eltIndex, SILLocation loc);
283+
284+
/// Given the adjoint value of an array initialized from an
285+
/// `array.uninitialized_intrinsic` application, accumulate the adjoint
286+
/// value's elements into the adjoint buffers of its element addresses.
287+
void accumulateArrayLiteralElementAddressAdjoints(
288+
SILBasicBlock *origBB, SILValue originalValue,
289+
AdjointValue arrayAdjointValue, SILLocation loc);
290+
277291
//--------------------------------------------------------------------------//
278292
// CFG mapping
279293
//--------------------------------------------------------------------------//
@@ -322,25 +336,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
322336

323337
void visitSILInstruction(SILInstruction *inst);
324338

325-
/// Given an array adjoint value, array element index and element tangent
326-
/// type, returns an `alloc_stack` containing the array element adjoint value.
327-
AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
328-
int eltIndex, SILType eltTanType,
329-
SILLocation loc);
330-
331-
/// Accumulate array element adjoint buffer into `store` source.
332-
void accumulateArrayElementAdjointDirect(StoreInst *si,
333-
AllocStackInst *eltAdjBuffer);
334-
335-
/// Accumulate array element adjoint buffer into `copy_addr` source.
336-
void accumulateArrayElementAdjointIndirect(CopyAddrInst *cai,
337-
AllocStackInst *eltAdjBuffer);
338-
339-
/// Given a `store` or `copy_addr` instruction whose destination is an element
340-
/// address from an `array.uninitialized_intrinsic` application, accumulate
341-
/// array element adjoint into the source's adjoint.
342-
void accumulateArrayElementAdjoint(SILInstruction *inst);
343-
344339
void visitApplyInst(ApplyInst *ai);
345340

346341
/// Handle `struct` instruction.

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
1010
//
1111
//===----------------------------------------------------------------------===//
12+
1213
#define DEBUG_TYPE "differentiation"
1314

1415
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
@@ -343,18 +344,25 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
343344
auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser());
344345
assert(ptai && "Expected `pointer_to_address` user for uninitialized "
345346
"array intrinsic");
346-
// Propagate usefulness through array element addresses.
347-
// - Find `store` and `copy_addr` instructions with array element
348-
// address destinations.
349-
// - For each instruction, set destination (array element address) as
350-
// useful and propagate usefulness through source.
347+
setUseful(ptai, dependentVariableIndex);
348+
// Propagate usefulness through array element addresses:
349+
// `pointer_to_address` and `index_addr` instructions.
350+
//
351+
// - Set all array element addresses as useful.
352+
// - Find instructions with array element addresses as "result":
353+
// - `store` and `copy_addr` with array element address as destination.
354+
// - `apply` with array element address as an indirect result.
355+
// - For each instruction, propagate usefulness through "arguments":
356+
// - `store` and `copy_addr`: propagate to source.
357+
// - `apply`: propagate to arguments.
351358
//
352-
// Note: `propagateUseful(use->getUser(), ...)` is intentionally not used
359+
// NOTE: `propagateUseful(use->getUser(), ...)` is intentionally not used
353360
// because it marks more values than necessary as useful, including:
354361
// - The `RawPointer` result of the intrinsic.
355-
// - The `pointer_to_address` user of the `RawPointer`.
356-
// - `index_addr` and `integer_literal` instructions for indexing the
362+
// - `integer_literal` operands to `index_addr` for indexing the
357363
// `RawPointer`.
364+
// It is also blocked by TF-1032: control flow differentiation crash for
365+
// active values with no tangent space.
358366
for (auto use : ptai->getUses()) {
359367
auto *user = use->getUser();
360368
if (auto *si = dyn_cast<StoreInst>(user)) {
@@ -364,7 +372,12 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
364372
setUseful(cai->getDest(), dependentVariableIndex);
365373
setUsefulAndPropagateToOperands(cai->getSrc(),
366374
dependentVariableIndex);
375+
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
376+
if (FullApplySite(ai).isIndirectResultOperand(*use))
377+
for (auto arg : ai->getArgumentsWithoutIndirectResults())
378+
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
367379
} else if (auto *iai = dyn_cast<IndexAddrInst>(user)) {
380+
setUseful(iai, dependentVariableIndex);
368381
for (auto use : iai->getUses()) {
369382
auto *user = use->getUser();
370383
if (auto si = dyn_cast<StoreInst>(user)) {
@@ -375,6 +388,10 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
375388
setUseful(cai->getDest(), dependentVariableIndex);
376389
setUsefulAndPropagateToOperands(cai->getSrc(),
377390
dependentVariableIndex);
391+
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
392+
if (FullApplySite(ai).isIndirectResultOperand(*use))
393+
for (auto arg : ai->getArgumentsWithoutIndirectResults())
394+
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
378395
}
379396
}
380397
}

0 commit comments

Comments
 (0)