Skip to content

Commit d7b12c2

Browse files
committed
Add requested test.
Fix `ref_element_addr` adjoint buffer type remapping. Garden tests.
1 parent e6c9886 commit d7b12c2

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
376376
auto *tanField = cast<VarDecl>(tanFieldLookup.front());
377377
// Create a local allocation for the element adjoint buffer.
378378
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
379-
auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType);
379+
auto eltTanSILType =
380+
remapType(SILType::getPrimitiveAddressType(eltTanType));
380381
auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
381382
builder.emitScopedBorrowOperation(
382383
loc, adjClass, [&](SILValue borrowedAdjClass) {
@@ -1090,7 +1091,7 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
10901091
auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
10911092
auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
10921093
auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
1093-
auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType);
1094+
auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType));
10941095
// Get `function_ref` and generic signature of
10951096
// `Array.TangentVector.subscript.getter`.
10961097
auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
@@ -1602,12 +1603,11 @@ void PullbackEmitter::visitLoadOperation(SingleValueInstruction *inst) {
16021603
void PullbackEmitter::visitStoreOperation(SILBasicBlock *bb, SILLocation loc,
16031604
SILValue origSrc, SILValue origDest) {
16041605
auto &adjBuf = getAdjointBuffer(bb, origDest);
1605-
auto bufType = remapType(adjBuf->getType());
16061606
auto adjVal =
16071607
builder.emitLoadValueOperation(loc, adjBuf, LoadOwnershipQualifier::Take);
16081608
recordTemporary(adjVal);
16091609
addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc);
1610-
emitZeroIndirect(bufType.getASTType(), adjBuf, loc);
1610+
emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc);
16111611
}
16121612

16131613
void PullbackEmitter::visitStoreInst(StoreInst *si) {

test/AutoDiff/downstream/class_differentiation.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@ ClassTests.test("TrivialMember") {
1616
@noDerivative
1717
final var noDerivative: Float = 1
1818

19+
@differentiable
1920
init(_ float: Float) {
2021
self.float = float
2122
}
2223

24+
@differentiable
25+
convenience init(convenience x: Float) {
26+
self.init(x)
27+
}
28+
2329
@differentiable
2430
func method(_ x: Float) -> Float {
2531
x * float
@@ -44,6 +50,7 @@ ClassTests.test("TrivialMember") {
4450
}
4551
// Test class initializer differentiation.
4652
expectEqual(10, pullback(at: 3, in: { C($0) })(.init(float: 10)))
53+
expectEqual(10, pullback(at: 3, in: { C(convenience: $0) })(.init(float: 10)))
4754
// Test class method differentiation.
4855
expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
4956
expectEqual(.init(float: 0), gradient(at: C(10), in: { c in c.testNoDerivative() }))
@@ -56,6 +63,7 @@ ClassTests.test("NontrivialMember") {
5663
@differentiable
5764
var float: Tracked<Float>
5865

66+
@differentiable
5967
init(_ float: Tracked<Float>) {
6068
self.float = float
6169
}
@@ -84,6 +92,26 @@ ClassTests.test("NontrivialMember") {
8492
gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) }))
8593
}
8694

95+
ClassTests.test("GenericNontrivialMember") {
96+
final class C<T: Differentiable>: Differentiable where T == T.TangentVector {
97+
@differentiable
98+
var x: Tracked<T>
99+
100+
@differentiable
101+
init(_ x: T) {
102+
self.x = Tracked(x)
103+
}
104+
105+
@differentiable
106+
convenience init(convenience x: T) {
107+
self.init(x)
108+
}
109+
}
110+
// Test class initializer differentiation.
111+
expectEqual(10, pullback(at: 3, in: { C<Float>($0) })(.init(x: 10)))
112+
expectEqual(10, pullback(at: 3, in: { C<Float>(convenience: $0) })(.init(x: 10)))
113+
}
114+
87115
// TF-1149: Test class with loadable type but address-only `TangentVector` type.
88116
// TODO(TF-1149): Uncomment when supported.
89117
/*

0 commit comments

Comments
 (0)