-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Array literal differentiation fixes. #28889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I tried various ways to simplify the pullback generation changes in this PR but wasn't yet successful. I'll leave some more comments describing the simplifications I tried and why they didn't work.
Activity analysis: - Mark array literal element addresses (`pointer_to_address` and `index_addr` instructions) as useful. These addresses legitimately need a derivative. - Propagate usefulness through array literal element addresses that are `apply` indirect results to the `apply` arguments. Pullback generation: - Add special case for array literal element addresses to `PullbackEmitter::getAdjointProjection`. The adjoint projection is a local allocation initialized from the array literal's adjoint value by applying `Array.TangentVector.subscript`. - This generalizes the old logic in `PullbackEmitter::visit{Store,CopyAddr}Inst` for handling array literal element addresses, which is now removed. - When an array literal's adjoint value is updated via `PullbackEmitter::addAdjointValue`, accumulate the array literal's adjoint value into the adjoint buffers of its element addresses. Resolves cleanup task: TF-976. Fixes correctness issues: - TF-975: nested array literals. - TF-978: array literal element address initialized as `apply` indirect result. Update activity analysis and derivative correctness tests.
if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) { | ||
auto adjBase = getAdjointBuffer(origBB, bai->getOperand()); | ||
if (errorOccurred) | ||
return (bufferMap[{origBB, originalProjection}] = SILValue()); | ||
// Return the base buffer's adjoint buffer. | ||
return adjBase; | ||
} | ||
// Handle `array.uninitialized_intrinsic` application element addresses. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR introduces two main changes to pullback generation:
- Adds a case to
PullbackEmitter::getAdjointProjection
for array literal element addresses. - Adds special logic to
PullbackEmitter::addAdjointValue
: when adding adjoint value for array literal, accumulate also for its element addresses.
With (1) and (2), TF-975 and TF-978 tests are fixed. I wondered if (1) can be dropped: perhaps array literal element addresses don't need special adjoint buffer initialization, and special logic for accumulating element address adjoint buffers in PullbackEmitter::addAdjointValue
or PullbackEmitter::setAdjointValue
is sufficient.
Here are the results:
Dropping (1) and keeping (2) for PullbackEmitter::addAdjointValue
leads to correctness issues:
[ RUN ] ArrayAutoDiff.ArrayIdentity
[ OK ] ArrayAutoDiff.ArrayIdentity
[ RUN ] ArrayAutoDiff.ArraySubscript
[ OK ] ArrayAutoDiff.ArraySubscript
[ RUN ] ArrayAutoDiff.ArrayLiteral
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 38
stdout>>> expected: Tracked(1.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 38
stdout>>> expected: Tracked(2.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 49
stdout>>> expected: Tracked(8.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 49
stdout>>> expected: Tracked(6.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 58
stdout>>> expected: Tracked(8.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 58
stdout>>> expected: Tracked(6.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 68
stdout>>> expected: Tracked(8.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> check failed at /Users/danielzheng/swift-bart/swift/test/AutoDiff/array.swift, line 68
stdout>>> expected: Tracked(6.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
stdout>>> actual: Tracked(0.0) (of type DifferentiationUnittest.Tracked<Swift.Float>)
[ FAIL ] ArrayAutoDiff.ArrayLiteral
...
This may be because PullbackEmitter::addAdjointValue
isn't always called for array literal values. PullbackEmitter::setAdjointValue
should always be called, though.
Dropping (1) and keeping (2) for PullbackEmitter::setAdjointValue
leads to SIL verification failed: Basic blocks must end with a terminator instruction: isa<TermInst>(BB.back())
issues: https://gist.github.com/dan-zheng/504a2ab65f8b999b8ce9478854d5f46b
The insertion point of builder
in PullbackEmitter::accumulateArrayLiteralElementAddressAdjoints
seems sensitive. Perhaps it can be massaged in a way that fixes the SIL verification errors while enabling (1) to be dropped.
getNextFunctionLocalAllocationInsertionPoint()); | ||
auto *eltAdjBuffer = localAllocBuilder.createAllocStack(loc, eltTanSILType); | ||
functionLocalAllocations.push_back(eltAdjBuffer); | ||
// Temporarily change global builder insertion point and emit zero into the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here are the errors from commenting the highlighted code (zero initialization + destroy): https://gist.github.com/dan-zheng/56e1233db1646cd4aceaeef23d02555b
SIL memory lifetime failure in @AD__$s4mainyycfU6_24controlFlowNestedLiteralL_ySay23DifferentiationUnittest7TrackedVySfGGAF_AFSbtF__pullback_src_0_wrt_0_1: memory is not initialized, but should
memory location: %27 = alloc_stack $Array<Tracked<Float>>.DifferentiableView // users: %1616, %1615, %536
at instruction: %536 = apply %534<Array<Tracked<Float>>.DifferentiableView>(%152, %27, %535) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
...
This error occurs for functions with control flow, since the local allocation isn't guaranteed to be initialized along all control flow paths.
Verifying that tests pass. |
do { | ||
// Test nested array literal and control flow. | ||
func controlFlowNestedLiteral( | ||
_ x: Tracked<Float>, _ y: Tracked<Float>, _ bool: Bool = true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 extra spaces, blargh - I'll fix this later
_ x: Tracked<Float>, _ y: Tracked<Float>, _ bool: Bool = true | |
_ x: Tracked<Float>, _ y: Tracked<Float>, _ bool: Bool = true |
// | ||
// Note: `propagateUseful(use->getUser(), ...)` is intentionally not used | ||
// NOTE: `propagateUseful(use->getUser(), ...)` is intentionally not used |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: rather than pattern-matching yet another special case of "writing to an array literal element address" (store
, copy_addr
, and now apply
indirect result) here in DifferentiableActivityInfo::setUsefulThroughArrayInitialization
, it may be nice to call a blanket propagateUseful(use->getUser(), ...)
, at risk of marking some "junk" values as useful:
- The
array.uninitialized_intrinsic
RawPointer
result integer_literal
operands toindex_addr
element addresses.
In practice, calling propagateUseful(use->getUser(), ...)
is blocked by TF-1032, which I plan to investigate soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm merging this PR now since it's blocking a user (see comments on TF-975).
I'm not particularly happy with the complexity introduced by these changes - other PR comments (here and here) describe opportunities for simplification. I plan to reinvestigate these simplifications after resolving TF-977: the last known array literal differentiation bug FWICT.
Happy to address any review feedback!
Activity analysis:
pointer_to_address
andindex_addr
instructions) as useful. These addresses legitimately need a derivative.
apply
indirect results to the
apply
arguments.Pullback generation:
PullbackEmitter::getAdjointProjection
. The adjoint projection is a localallocation initialized from the array literal's adjoint value by applying
Array.TangentVector.subscript
.PullbackEmitter::visit{Store,CopyAddr}Inst
for handling array literal element addresses, which is now removed.
PullbackEmitter::addAdjointValue
, accumulate the array literal's adjointvalue into the adjoint buffers of its element addresses.
Resolves cleanup task: TF-976.
Fixes correctness issues:
apply
indirect result.Update activity analysis and derivative correctness tests.
Examples: