Skip to content

Commit 358c1ce

Browse files
committed
Fix incorrect ref_element_addr derivative values.
Add a `ref_element_addr` case to `PullbackEmitter::getAdjointProjection`. The adjoint projection of a `ref_element_addr` is a local allocation initialized with the corresponding field value from the class's base adjoint value. This fixes `ref_element_addr` derivatives that were incorrectly zero. Also add class initializer differentiation tests.
1 parent 610214a commit 358c1ce

File tree

3 files changed: 78 additions & 19 deletions

include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -266,10 +266,17 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
266266

267267
SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint();
268268

269+
/// Creates and returns a local allocation with the given type.
270+
///
271+
/// Local allocations are created uninitialized in the pullback entry and
272+
/// deallocated in the pullback exit. All local allocations not in
273+
/// `destroyedLocalAllocations` are also destroyed in the pullback exit.
274+
AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc);
275+
269276
SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer);
270277

271-
// Accumulates `rhsBufferAccess` into the adjoint buffer corresponding to
272-
// `originalBuffer`.
278+
/// Accumulates `rhsBufferAccess` into the adjoint buffer corresponding to
279+
/// `originalBuffer`.
273280
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
274281
SILValue rhsBufferAccess, SILLocation loc);
275282

lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp

Lines changed: 58 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
#define DEBUG_TYPE "differentiation"
2020

2121
#include "swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h"
22+
#include "swift/SIL/InstructionUtils.h"
2223
#include "swift/SIL/Projection.h"
2324
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
2425
#include "swift/SILOptimizer/Utils/Differentiation/ADContext.h"
@@ -351,6 +352,43 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
351352
}
352353
return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex);
353354
}
355+
// Handle `ref_element_addr`.
356+
// Adjoint projection: a local allocation initialized with the corresponding
357+
// field value from the class's base adjoint value.
358+
if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) {
359+
assert(!reai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
360+
"`@noDerivative` class projections should never be active");
361+
auto loc = reai->getLoc();
362+
// Get the class operand, stripping `begin_borrow`.
363+
auto classOperand = stripBorrow(reai->getOperand());
364+
// Get the class operand's adjoint value. Currently, it must be a
365+
// `TangentVector` struct.
366+
auto adjClass =
367+
materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc);
368+
auto *tangentVectorDecl =
369+
adjClass->getType().getStructOrBoundGenericStruct();
370+
// TODO(TF-970): Replace assertions below with diagnostics.
371+
assert(tangentVectorDecl && "`TangentVector` of a class must be a struct");
372+
auto tanFieldLookup =
373+
tangentVectorDecl->lookupDirect(reai->getField()->getName());
374+
assert(tanFieldLookup.size() == 1 &&
375+
"Class `TangentVector` must have field of the same name");
376+
auto *tanField = cast<VarDecl>(tanFieldLookup.front());
377+
// Create a local allocation for the element adjoint buffer.
378+
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
379+
auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType);
380+
auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
381+
builder.emitScopedBorrowOperation(
382+
loc, adjClass, [&](SILValue borrowedAdjClass) {
383+
// Initialize the element adjoint buffer with the base adjoint value.
384+
auto *adjElt =
385+
builder.createStructExtract(loc, borrowedAdjClass, tanField);
386+
auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt);
387+
builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer,
388+
StoreOwnershipQualifier::Init);
389+
});
390+
return eltAdjBuffer;
391+
}
354392
// Handle `begin_access`.
355393
// Adjoint projection: the base adjoint buffer itself.
356394
if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
@@ -412,6 +450,20 @@ PullbackEmitter::getNextFunctionLocalAllocationInsertionPoint() {
412450
return lastLocalAlloc->getDefiningInstruction()->getIterator();
413451
}
414452

453+
AllocStackInst *
454+
PullbackEmitter::createFunctionLocalAllocation(SILType type, SILLocation loc) {
455+
// Set insertion point for local allocation builder: before the last local
456+
// allocation, or at the start of the pullback function's entry if no local
457+
// allocations exist yet.
458+
localAllocBuilder.setInsertionPoint(
459+
getPullback().getEntryBlock(),
460+
getNextFunctionLocalAllocationInsertionPoint());
461+
// Create and return local allocation.
462+
auto *alloc = localAllocBuilder.createAllocStack(loc, type);
463+
functionLocalAllocations.push_back(alloc);
464+
return alloc;
465+
}
466+
415467
SILValue &PullbackEmitter::getAdjointBuffer(SILBasicBlock *origBB,
416468
SILValue originalBuffer) {
417469
assert(originalBuffer->getType().isAddress());
@@ -425,25 +477,19 @@ SILValue &PullbackEmitter::getAdjointBuffer(SILBasicBlock *origBB,
425477
if (auto adjProj = getAdjointProjection(origBB, originalBuffer))
426478
return (bufferMap[{origBB, originalBuffer}] = adjProj);
427479

480+
auto bufObjectType = getRemappedTangentType(originalBuffer->getType());
428481
// Set insertion point for local allocation builder: before the last local
429482
// allocation, or at the start of the pullback function's entry if no local
430483
// allocations exist yet.
431-
localAllocBuilder.setInsertionPoint(
432-
getPullback().getEntryBlock(),
433-
getNextFunctionLocalAllocationInsertionPoint());
434-
// Allocate local buffer and initialize to zero.
435-
auto bufObjectType = getRemappedTangentType(originalBuffer->getType());
436-
auto *newBuf = localAllocBuilder.createAllocStack(
437-
RegularLocation::getAutoGeneratedLocation(), bufObjectType);
484+
auto *newBuf = createFunctionLocalAllocation(
485+
bufObjectType, RegularLocation::getAutoGeneratedLocation());
438486
// Temporarily change global builder insertion point and emit zero into the
439-
// local buffer.
487+
// local allocation.
440488
auto insertionPoint = builder.getInsertionBB();
441489
builder.setInsertionPoint(localAllocBuilder.getInsertionBB(),
442490
localAllocBuilder.getInsertionPoint());
443491
emitZeroIndirect(bufObjectType.getASTType(), newBuf, newBuf->getLoc());
444492
builder.setInsertionPoint(insertionPoint);
445-
// Register the local buffer.
446-
functionLocalAllocations.push_back(newBuf);
447493
return (insertion.first->getSecond() = newBuf);
448494
}
449495

@@ -1082,13 +1128,9 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
10821128
{addArithConf, diffConf});
10831129
// %elt_adj = alloc_stack $T.TangentVector
10841130
// Create and register a local allocation.
1085-
localAllocBuilder.setInsertionPoint(
1086-
getPullback().getEntryBlock(),
1087-
getNextFunctionLocalAllocationInsertionPoint());
1088-
auto *eltAdjBuffer = localAllocBuilder.createAllocStack(loc, eltTanSILType);
1089-
functionLocalAllocations.push_back(eltAdjBuffer);
1131+
auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
10901132
// Temporarily change global builder insertion point and emit zero into the
1091-
// local buffer.
1133+
// local allocation.
10921134
auto insertionPoint = builder.getInsertionBB();
10931135
builder.setInsertionPoint(localAllocBuilder.getInsertionBB(),
10941136
localAllocBuilder.getInsertionPoint());

test/AutoDiff/downstream/class_differentiation.swift

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,17 @@ ClassTests.test("TrivialMember") {
3434
static func controlFlow(_ c1: C, _ c2: C, _ flag: Bool) -> Float {
3535
var result: Float = 0
3636
if flag {
37-
result = c1.float * c2.float
37+
var c3 = C(c1.float * c2.float)
38+
result = c3.float
3839
} else {
3940
result = c2.float * c1.float
4041
}
4142
return result
4243
}
4344
}
45+
// Test class initializer differentiation.
46+
expectEqual(10, pullback(at: 3, in: { C($0) })(.init(float: 10)))
47+
// Test class method differentiation.
4448
expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
4549
expectEqual(.init(float: 0), gradient(at: C(10), in: { c in c.testNoDerivative() }))
4650
expectEqual((.init(float: 20), .init(float: 10)),
@@ -72,6 +76,9 @@ ClassTests.test("NontrivialMember") {
7276
return result
7377
}
7478
}
79+
// Test class initializer differentiation.
80+
expectEqual(10, pullback(at: 3, in: { C($0) })(.init(float: 10)))
81+
// Test class method differentiation.
7582
expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
7683
expectEqual((.init(float: 20), .init(float: 10)),
7784
gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) }))
@@ -94,6 +101,9 @@ ClassTests.test("AddressOnlyTangentVector") {
94101
stored
95102
}
96103
}
104+
// Test class initializer differentiation.
105+
expectEqual(10, pullback(at: 3, in: { C<Float>($0) })(.init(float: 10)))
106+
// Test class method differentiation.
97107
expectEqual((.init(stored: Float(3)), 10),
98108
gradient(at: C<Float>(3), 3, in: { c, x in c.method(x) }))
99109
}

0 commit comments

Comments
 (0)