Skip to content

Commit cfe0bc7

Browse files
authored
[AutoDiff] Make sure AdjointValues are always in the cotangent space (#21633)
1 parent a123f9b commit cfe0bc7

File tree

3 files changed

+136
-33
lines changed

3 files changed

+136
-33
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,21 @@ static LoadOwnershipQualifier getBufferLOQ(Type type, SILFunction &fn) {
252252
return LoadOwnershipQualifier::Unqualified;
253253
}
254254

255+
/// Assuming the given type conforms to `Differentiable`, returns the associated
256+
/// cotangent space type.
257+
static SILType getCotangentType(CanType type, SILModule &mod) {
258+
return SILType::getPrimitiveObjectType(
259+
type->getAutoDiffAssociatedVectorSpace(
260+
AutoDiffAssociatedVectorSpaceKind::Cotangent,
261+
LookUpConformanceInModule(mod.getSwiftModule()))->getCanonicalType());
262+
}
263+
264+
/// Assuming the given type conforms to `Differentiable`, returns the associated
265+
/// cotangent space type.
266+
static SILType getCotangentType(SILType type, SILModule &mod) {
267+
return getCotangentType(type.getASTType(), mod);
268+
}
269+
255270
//===----------------------------------------------------------------------===//
256271
// Auxiliary data structures
257272
//===----------------------------------------------------------------------===//
@@ -2891,7 +2906,6 @@ class AdjointValue {
28912906
private:
28922907
static bool isLegalAggregate(ArrayRef<AdjointValue> elements, SILType type) {
28932908
if (auto *structDecl = type.getASTType()->getStructOrBoundGenericStruct()) {
2894-
// TODO: Check whether this struct is @_fixed_layout and ABI public.
28952909
for (auto pair : llvm::zip(structDecl->getStoredProperties(), elements))
28962910
if (!std::get<0>(pair)->getType()->getCanonicalType()
28972911
->isEqual(std::get<1>(pair).getSwiftType()))
@@ -3130,7 +3144,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
31303144
AdjointValue getAdjointValue(SILValue originalValue) {
31313145
assert(originalValue->getFunction() == &getOriginal());
31323146
auto insertion = adjointMap.try_emplace(
3133-
originalValue, AdjointValue::getZero(originalValue->getType()));
3147+
originalValue, AdjointValue::getZero(
3148+
getCotangentType(originalValue->getType(), getModule())));
31343149
return insertion.first->getSecond();
31353150
}
31363151

@@ -3140,6 +3155,15 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
31403155
AdjointValue adjointValue) {
31413156
assert(originalValue->getFunction() == &getOriginal());
31423157
LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue);
3158+
#ifndef NDEBUG
3159+
auto origTy = originalValue->getType().getASTType();
3160+
auto cotanSpace = origTy->getAutoDiffAssociatedVectorSpace(
3161+
AutoDiffAssociatedVectorSpaceKind::Cotangent,
3162+
LookUpConformanceInModule(getModule().getSwiftModule()));
3163+
// The adjoint value must be in the cotangent space.
3164+
assert(cotanSpace && adjointValue.getType().getASTType()
3165+
== cotanSpace->getCanonicalType());
3166+
#endif
31433167
auto insertion = adjointMap.try_emplace(originalValue, adjointValue);
31443168
auto inserted = insertion.second;
31453169
auto &value = insertion.first->getSecond();
@@ -3644,25 +3668,32 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
36443668
void visitStructInst(StructInst *si) {
36453669
auto *decl = si->getStructDecl();
36463670
auto av = getAdjointValue(si);
3647-
auto loc = si->getLoc();
36483671
switch (av.getKind()) {
36493672
case AdjointValue::Zero:
36503673
for (auto *field : decl->getStoredProperties()) {
36513674
auto fv = si->getFieldValue(field);
3652-
addAdjointValue(fv, AdjointValue::getZero(fv->getType()));
3675+
addAdjointValue(
3676+
fv, AdjointValue::getZero(getCotangentType(fv->getType(),
3677+
getModule())));
36533678
}
36543679
break;
36553680
case AdjointValue::Materialized: {
3656-
auto adjY = av.getMaterializedValue();
3657-
for (auto *field : decl->getStoredProperties())
3658-
addAdjointValue(si->getFieldValue(field),
3659-
builder.createStructExtract(loc, adjY, field));
3660-
break;
3681+
// FIXME(SR-9602): If `CotangentVector` is not marked
3682+
// `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer.
3683+
// auto adjY = av.getMaterializedValue();
3684+
// for (auto *field : decl->getStoredProperties())
3685+
// addAdjointValue(si->getFieldValue(field),
3686+
// builder.createStructExtract(loc, adjY, field));
3687+
llvm_unreachable("Unhandled. Are you trying to differentiate a "
3688+
"memberwise initializer?");
36613689
}
36623690
case AdjointValue::Aggregate: {
3663-
for (auto pair : llvm::zip(si->getElements(), av.getAggregateElements()))
3664-
addAdjointValue(std::get<0>(pair), std::get<1>(pair));
3665-
break;
3691+
// FIXME(SR-9602): If `CotangentVector` is not marked
3692+
// `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer.
3693+
// for (auto pair : llvm::zip(si->getElements(), av.getAggregateElements()))
3694+
// addAdjointValue(std::get<0>(pair), std::get<1>(pair));
3695+
llvm_unreachable("Unhandled. Are you trying to differentiate a "
3696+
"memberwise initializer?");
36663697
}
36673698
}
36683699
}
@@ -3739,9 +3770,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
37393770
if (field == correspondingField)
37403771
eltVals.push_back(av);
37413772
else
3742-
eltVals.push_back(
3743-
AdjointValue::getZero(SILType::getPrimitiveObjectType(
3744-
field->getType()->getCanonicalType())));
3773+
eltVals.push_back(AdjointValue::getZero(
3774+
getCotangentType(field->getType()->getCanonicalType(),
3775+
getModule())));
37453776
}
37463777
addAdjointValue(sei->getOperand(),
37473778
AdjointValue::getAggregate(cotangentVectorSILTy,
@@ -3789,12 +3820,15 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
37893820
switch (av.getKind()) {
37903821
case AdjointValue::Kind::Zero:
37913822
for (auto eltVal : ti->getElements())
3792-
addAdjointValue(eltVal, AdjointValue::getZero(eltVal->getType()));
3823+
addAdjointValue(eltVal,
3824+
AdjointValue::getZero(getCotangentType(eltVal->getType(),
3825+
getModule())));
37933826
break;
37943827
case AdjointValue::Kind::Materialized:
37953828
for (auto i : range(ti->getNumOperands()))
37963829
addAdjointValue(ti->getOperand(i),
3797-
builder.createTupleExtract(ti->getLoc(), ti, i));
3830+
builder.createTupleExtract(
3831+
ti->getLoc(), av.getMaterializedValue(), i));
37983832
break;
37993833
case AdjointValue::Kind::Aggregate:
38003834
for (auto pair : llvm::zip(ti->getElements(), av.getAggregateElements()))
@@ -3809,28 +3843,26 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
38093843
/// adj[x] = tuple (0, 0, ..., adj[y], ..., 0, 0)
38103844
void visitTupleExtractInst(TupleExtractInst *tei) {
38113845
auto *tupleTy = tei->getTupleType();
3846+
auto tupleCotanTy = getCotangentType(tupleTy->getCanonicalType(),
3847+
getModule());
38123848
auto av = getAdjointValue(tei);
38133849
switch (av.getKind()) {
38143850
case AdjointValue::Kind::Zero:
3815-
addAdjointValue(tei->getOperand(),
3816-
AdjointValue::getZero(SILType::getPrimitiveObjectType(
3817-
tupleTy->getCanonicalType())));
3851+
addAdjointValue(tei->getOperand(), AdjointValue::getZero(tupleCotanTy));
38183852
break;
38193853
case AdjointValue::Kind::Aggregate:
38203854
case AdjointValue::Kind::Materialized: {
38213855
SmallVector<AdjointValue, 8> elements;
38223856
for (unsigned i : range(tupleTy->getNumElements())) {
38233857
if (tei->getFieldNo() == i)
38243858
elements.push_back(av);
3825-
else {
3826-
auto eltTy = SILType::getPrimitiveObjectType(
3827-
tupleTy->getElementType(i)->getCanonicalType());
3828-
elements.push_back(AdjointValue::getZero(eltTy));
3829-
}
3859+
else
3860+
elements.push_back(AdjointValue::getZero(
3861+
getCotangentType(tupleTy->getElementType(i)->getCanonicalType(),
3862+
getModule())));
38303863
}
38313864
addAdjointValue(tei->getOperand(),
3832-
AdjointValue::getAggregate(tei->getOperand()->getType(),
3833-
elements, allocator));
3865+
AdjointValue::getAggregate(tupleCotanTy, elements, allocator));
38343866
break;
38353867
}
38363868
}
@@ -4406,7 +4438,7 @@ void DifferentiationTask::createEmptyPrimal() {
44064438
auto linkage = SILLinkage::Hidden;
44074439
primal = fb.getOrCreateFunction(
44084440
original->getLocation(), primalName, linkage, primalTy,
4409-
original->isBare(), original->isTransparent(), original->isSerialized());
4441+
original->isBare(), IsNotTransparent, original->isSerialized());
44104442
primal->setUnqualifiedOwnership();
44114443
LLVM_DEBUG(getADDebugStream() << "Primal function created \n"
44124444
<< *primal << '\n');
@@ -4531,7 +4563,7 @@ void DifferentiationTask::createEmptyAdjoint() {
45314563
auto linkage = SILLinkage::Hidden;
45324564
adjoint = fb.createFunction(
45334565
linkage, adjName, adjType, original->getGenericEnvironment(),
4534-
original->getLocation(), original->isBare(), original->isTransparent(),
4566+
original->getLocation(), original->isBare(), IsNotTransparent,
45354567
original->isSerialized());
45364568
adjoint->setUnqualifiedOwnership();
45374569
adjoint->setDebugScope(new (module)
@@ -4556,7 +4588,7 @@ void DifferentiationTask::createJVP() {
45564588
jvp = fb.createFunction(original->getLinkage(), jvpName, jvpType,
45574589
original->getGenericEnvironment(),
45584590
original->getLocation(), original->isBare(),
4559-
original->isTransparent(), original->isSerialized());
4591+
IsNotTransparent, original->isSerialized());
45604592
jvp->setUnqualifiedOwnership();
45614593
jvp->setDebugScope(new (module) SILDebugScope(original->getLocation(), jvp));
45624594
attr->setJVPName(jvp->getName());
@@ -4613,7 +4645,7 @@ void DifferentiationTask::createVJP() {
46134645
vjp = fb.createFunction(linkage, vjpName, vjpType,
46144646
original->getGenericEnvironment(),
46154647
original->getLocation(), original->isBare(),
4616-
original->isTransparent(), original->isSerialized());
4648+
IsNotTransparent, original->isSerialized());
46174649
vjp->setUnqualifiedOwnership();
46184650
vjp->setDebugScope(new (module)
46194651
SILDebugScope(original->getLocation(), vjp));
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
#if os(macOS)
6+
import Darwin.C
7+
#else
8+
import Glibc
9+
#endif
10+
11+
var SeparateCotangentTypeTests = TestSuite("SeparateCotangentType")
12+
13+
struct DifferentiableSubset : Differentiable {
14+
@differentiable(wrt: (self))
15+
var w: Float
16+
@differentiable(wrt: (self))
17+
var b: Float
18+
@noDerivative var flag: Bool
19+
20+
// @_fieldwiseProductSpace
21+
struct TangentVector : Differentiable, VectorNumeric {
22+
@_fieldwiseProductSpace
23+
typealias TangentVector = DifferentiableSubset.TangentVector
24+
@_fieldwiseProductSpace
25+
typealias CotangentVector = DifferentiableSubset.CotangentVector
26+
var w: Float
27+
var b: Float
28+
func tangentVector(from cotan: CotangentVector) -> TangentVector {
29+
return TangentVector(w: cotan.w, b: cotan.b)
30+
}
31+
}
32+
// @_fieldwiseProductSpace
33+
struct CotangentVector : Differentiable, VectorNumeric {
34+
@_fieldwiseProductSpace
35+
typealias TangentVector = DifferentiableSubset.CotangentVector
36+
@_fieldwiseProductSpace
37+
typealias CotangentVector = DifferentiableSubset.TangentVector
38+
var w: Float
39+
var b: Float
40+
func tangentVector(from cotan: CotangentVector) -> TangentVector {
41+
return TangentVector(w: cotan.w, b: cotan.b)
42+
}
43+
}
44+
func tangentVector(from cotan: CotangentVector) -> TangentVector {
45+
return TangentVector(w: cotan.w, b: cotan.b)
46+
}
47+
func moved(along v: TangentVector) -> DifferentiableSubset {
48+
return DifferentiableSubset(w: w.moved(along: v.w), b: b.moved(along: v.b), flag: flag)
49+
}
50+
}
51+
52+
SeparateCotangentTypeTests.test("Trivial") {
53+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
54+
let pb = pullback(at: x) { x in x }
55+
expectEqual(pb(DifferentiableSubset.CotangentVector.zero), DifferentiableSubset.CotangentVector.zero)
56+
}
57+
58+
SeparateCotangentTypeTests.test("Initialization") {
59+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
60+
let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) }
61+
expectEqual(pb(DifferentiableSubset.CotangentVector.zero), DifferentiableSubset.CotangentVector.zero)
62+
}
63+
64+
// FIXME(SR-9602): If `CotangentVector` is not marked
65+
// `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer.
66+
// SeparateCotangentTypeTests.test("SomeArithmetics") {
67+
// let x = DifferentiableSubset(w: 0, b: 1, flag: false)
68+
// let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
69+
// expectEqual(pb(DifferentiableSubset.CotangentVector.zero), DifferentiableSubset.CotangentVector.zero)
70+
// }
71+
72+
runAllTests()

test/AutoDiff/superset_adjoint_loaded.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,4 @@ func mul3(_ x: Float) -> Float {
1515
let _ = gradient(at: 0, in: mul3)
1616

1717
// CHECK-LABEL: sil{{.*}} @AD__{{.*}}mul3{{.*}}__primal{{.*}}
18-
// CHECK: function_ref static Float._adjointMultiply(_:_:_:_:)
19-
// CHECK: } // end sil function 'AD__{{.*}}mul3{{.*}}__primal{{.*}}'
18+
// CHECK: function_ref AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1

0 commit comments

Comments
 (0)