Skip to content

Commit db52338

Browse files
authored
add SILDeclRef modifier for autodiff functions (#21224)
Adds a new field to SILDeclRef that modifies it to refer to an autodiff function that is associated with the original one. Adds printing/parsing for it. Uses the SILDeclRef to TBDGen a public symbol that the AD pass creates.
1 parent a0a7ee5 commit db52338

File tree

10 files changed

+253
-42
lines changed

10 files changed

+253
-42
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ struct SILAutoDiffIndices {
316316
[&s](unsigned p) { s << p; }, [&s]{ s << ' '; });
317317
s << "))";
318318
}
319+
320+
std::string mangle() const {
321+
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
322+
interleave(parameters.set_bits(),
323+
[&](unsigned idx) { result += llvm::utostr(idx); },
324+
[&] { result += '_'; });
325+
return result;
326+
}
319327
};
320328

321329
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,9 @@ ERROR(expected_sil_value_ownership_kind,none,
483483
"expected value ownership kind in SIL code", ())
484484
ERROR(expected_sil_colon,none,
485485
"expected ':' before %0", (StringRef))
486+
// SWIFT_ENABLE_TENSORFLOW
487+
ERROR(malformed_autodiff_parameter_indices,none,
488+
"malformed autodiff parameter indices", ())
486489

487490
// SIL Values
488491
ERROR(sil_value_redefinition,none,

include/swift/SIL/SILDeclRef.h

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ namespace swift {
3434
enum class EffectsKind : uint8_t;
3535
class AbstractFunctionDecl;
3636
class AbstractClosureExpr;
37+
// SWIFT_ENABLE_TENSORFLOW
38+
class AutoDiffAssociatedFunctionIdentifier;
3739
class ValueDecl;
3840
class FuncDecl;
3941
class ClosureExpr;
@@ -149,15 +151,25 @@ struct SILDeclRef {
149151
/// The default argument index for a default argument getter.
150152
unsigned defaultArgIndex : 10;
151153

154+
// SWIFT_ENABLE_TENSORFLOW
155+
/// When this is non-null, it modifies the SILDeclRef to refer to the
156+
/// corresponding autodiff associated function.
157+
AutoDiffAssociatedFunctionIdentifier *autoDiffAssociatedFunctionIdentifier;
158+
152159
/// Produces a null SILDeclRef.
153160
SILDeclRef() : loc(), kind(Kind::Func),
154161
isCurried(0), isForeign(0), isDirectReference(0),
155-
defaultArgIndex(0) {}
162+
// SWIFT_ENABLE_TENSORFLOW
163+
defaultArgIndex(0),
164+
autoDiffAssociatedFunctionIdentifier(nullptr) {}
156165

157166
/// Produces a SILDeclRef of the given kind for the given decl.
158167
explicit SILDeclRef(ValueDecl *decl, Kind kind,
159168
bool isCurried = false,
160-
bool isForeign = false);
169+
// SWIFT_ENABLE_TENSORFLOW
170+
bool isForeign = false,
171+
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId =
172+
nullptr);
161173

162174
/// Produces a SILDeclRef for the given ValueDecl or
163175
/// AbstractClosureExpr:
@@ -284,7 +296,10 @@ struct SILDeclRef {
284296
&& isCurried == rhs.isCurried
285297
&& isForeign == rhs.isForeign
286298
&& isDirectReference == rhs.isDirectReference
287-
&& defaultArgIndex == rhs.defaultArgIndex;
299+
// SWIFT_ENABLE_TENSORFLOW
300+
&& defaultArgIndex == rhs.defaultArgIndex
301+
&& autoDiffAssociatedFunctionIdentifier ==
302+
rhs.autoDiffAssociatedFunctionIdentifier;
288303
}
289304
bool operator!=(SILDeclRef rhs) const {
290305
return !(*this == rhs);
@@ -303,15 +318,19 @@ struct SILDeclRef {
303318
bool willBeDirect = isDirectReference;
304319
return SILDeclRef(loc.getOpaqueValue(), kind,
305320
curried, willBeDirect, willBeForeign,
306-
defaultArgIndex);
321+
// SWIFT_ENABLE_TENSORFLOW
322+
defaultArgIndex,
323+
autoDiffAssociatedFunctionIdentifier);
307324
}
308325

309326
/// Returns the foreign (or native) entry point corresponding to the same
310327
/// decl.
311328
SILDeclRef asForeign(bool foreign = true) const {
312329
assert(!isCurried);
313330
return SILDeclRef(loc.getOpaqueValue(), kind,
314-
isCurried, isDirectReference, foreign, defaultArgIndex);
331+
// SWIFT_ENABLE_TENSORFLOW
332+
isCurried, isDirectReference, foreign, defaultArgIndex,
333+
autoDiffAssociatedFunctionIdentifier);
315334
}
316335

317336
SILDeclRef asDirectReference(bool direct = true) const {
@@ -322,6 +341,26 @@ struct SILDeclRef {
322341
return r;
323342
}
324343

344+
// SWIFT_ENABLE_TENSORFLOW
345+
/// Returns the entry point for the corresponding autodiff associated
346+
/// function.
347+
SILDeclRef asAutoDiffAssociatedFunction(
348+
AutoDiffAssociatedFunctionIdentifier *id) const {
349+
assert(!autoDiffAssociatedFunctionIdentifier);
350+
SILDeclRef r = *this;
351+
r.autoDiffAssociatedFunctionIdentifier = id;
352+
return r;
353+
}
354+
355+
/// Returns the entry point for the original function corresponding to an
356+
/// autodiff associated function.
357+
SILDeclRef asAutoDiffOriginalFunction() const {
358+
assert(autoDiffAssociatedFunctionIdentifier);
359+
SILDeclRef r = *this;
360+
r.autoDiffAssociatedFunctionIdentifier = nullptr;
361+
return r;
362+
}
363+
325364
/// True if the decl ref references a thunk from a natively foreign
326365
/// declaration to Swift calling convention.
327366
bool isForeignToNativeThunk() const;
@@ -392,12 +431,16 @@ struct SILDeclRef {
392431
bool isCurried,
393432
bool isDirectReference,
394433
bool isForeign,
395-
unsigned defaultArgIndex)
434+
// SWIFT_ENABLE_TENSORFLOW
435+
unsigned defaultArgIndex,
436+
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId)
396437
: loc(Loc::getFromOpaqueValue(opaqueLoc)),
397438
kind(kind),
398439
isCurried(isCurried),
399440
isForeign(isForeign), isDirectReference(isDirectReference),
400-
defaultArgIndex(defaultArgIndex)
441+
// SWIFT_ENABLE_TENSORFLOW
442+
defaultArgIndex(defaultArgIndex),
443+
autoDiffAssociatedFunctionIdentifier(autoDiffFuncId)
401444
{}
402445

403446
};
@@ -421,11 +464,13 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
421464

422465
static SILDeclRef getEmptyKey() {
423466
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func,
424-
false, false, false, 0);
467+
// SWIFT_ENABLE_TENSORFLOW
468+
false, false, false, 0, nullptr);
425469
}
426470
static SILDeclRef getTombstoneKey() {
427471
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func,
428-
false, false, false, 0);
472+
// SWIFT_ENABLE_TENSORFLOW
473+
false, false, false, 0, nullptr);
429474
}
430475
static unsigned getHashValue(swift::SILDeclRef Val) {
431476
unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue());
@@ -435,7 +480,10 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
435480
: UnsignedInfo::getHashValue(Val.isCurried);
436481
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
437482
unsigned h5 = UnsignedInfo::getHashValue(Val.isDirectReference);
438-
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
483+
// SWIFT_ENABLE_TENSORFLOW
484+
unsigned h6 =
485+
PointerInfo::getHashValue(Val.autoDiffAssociatedFunctionIdentifier);
486+
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11) ^ (h6 << 13);
439487
}
440488
static bool isEqual(swift::SILDeclRef const &LHS,
441489
swift::SILDeclRef const &RHS) {

lib/ParseSIL/ParseSIL.cpp

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,11 +1584,14 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {
15841584
.Default(None);
15851585
}
15861586

1587+
// SWIFT_ENABLE_TENSORFLOW
15871588
/// sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref?
15881589
/// sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-uncurry-level)?
1589-
/// ('.' sil-decl-lang)?
1590+
/// ('.' sil-decl-lang)? ('.' sil-decl-autodiff)?
15901591
/// sil-decl-subref ::= '!' sil-decl-uncurry-level ('.' sil-decl-lang)?
1591-
/// sil-decl-subref ::= '!' sil-decl-lang
1592+
/// ('.' sil-decl-autodiff)?
1593+
/// sil-decl-subref ::= '!' sil-decl-lang ('.' sil-decl-autodiff)?
1594+
/// sil-decl-subref ::= '!' sil-decl-autodiff
15921595
/// sil-decl-subref-part ::= 'getter'
15931596
/// sil-decl-subref-part ::= 'setter'
15941597
/// sil-decl-subref-part ::= 'allocator'
@@ -1598,6 +1601,12 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {
15981601
/// sil-decl-subref-part ::= 'globalaccessor'
15991602
/// sil-decl-uncurry-level ::= [0-9]+
16001603
/// sil-decl-lang ::= 'foreign'
1604+
/// sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-order
1605+
/// '.' sil-decl-autodiff-indices
1606+
/// sil-decl-autodiff-kind ::= 'jvp'
1607+
/// sil-decl-autodiff-kind ::= 'vjp'
1608+
/// sil-decl-autodiff-order ::= [0-9]+
1609+
/// sil-decl-autodiff-indices ::= [FM][SU]+
16011610
bool SILParser::parseSILDeclRef(SILDeclRef &Result,
16021611
SmallVectorImpl<ValueDecl *> &values) {
16031612
ValueDecl *VD;
@@ -1608,6 +1617,8 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
16081617
SILDeclRef::Kind Kind = SILDeclRef::Kind::Func;
16091618
unsigned uncurryLevel = 0;
16101619
bool IsObjC = false;
1620+
// SWIFT_ENABLE_TENSORFLOW
1621+
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId = nullptr;
16111622

16121623
if (!P.consumeIf(tok::sil_exclamation)) {
16131624
// Construct SILDeclRef.
@@ -1619,10 +1630,13 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
16191630

16201631
// Handle sil-constant-kind-and-uncurry-level.
16211632
// ParseState indicates the value we just handled.
1622-
// 1 means we just handled Kind, 2 means we just handled uncurryLevel.
1623-
// We accept func|getter|setter|...|foreign or an integer when ParseState is
1624-
// 0; accept foreign or an integer when ParseState is 1; accept foreign when
1625-
// ParseState is 2.
1633+
// SWIFT_ENABLE_TENSORFLOW
1634+
// 1 means we just handled Kind, 2 means we just handled uncurryLevel, 3 means
1635+
// we just handled foreign.
1636+
// We accept func|getter|setter|...|foreign, an autodiff identifier, or an
1637+
// integer when ParseState is 0; accept foreign, an autodiff identifier, or an
1638+
// integer when ParseState is 1; accept foreign or an autodiff identifier when
1639+
// ParseState is 2; accept an autodiff identifier when ParseState is 3.
16261640
unsigned ParseState = 0;
16271641
Identifier Id;
16281642
do {
@@ -1682,8 +1696,49 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
16821696
} else if (!ParseState && Id.str() == "propertyinit") {
16831697
Kind = SILDeclRef::Kind::StoredPropertyInitializer;
16841698
ParseState = 1;
1685-
} else if (Id.str() == "foreign") {
1699+
// SWIFT_ENABLE_TENSORFLOW
1700+
} else if (ParseState < 3 && Id.str() == "foreign") {
16861701
IsObjC = true;
1702+
// SWIFT_ENABLE_TENSORFLOW
1703+
ParseState = 3;
1704+
} else if (Id.str() == "jvp" || Id.str() == "vjp") {
1705+
AutoDiffAssociatedFunctionKind kind;
1706+
unsigned differentiationOrder;
1707+
AutoDiffParameterIndices *parameterIndices = nullptr;
1708+
1709+
if (Id.str() == "jvp")
1710+
kind = AutoDiffAssociatedFunctionKind::JVP;
1711+
else if (Id.str() == "vjp")
1712+
kind = AutoDiffAssociatedFunctionKind::VJP;
1713+
else
1714+
llvm_unreachable("Should only have JVP and VJP here");
1715+
1716+
if (!P.consumeIf(tok::period)) {
1717+
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, ".");
1718+
return true;
1719+
}
1720+
1721+
if (parseInteger(differentiationOrder,
1722+
diag::sil_const_expected_int_value))
1723+
return true;
1724+
1725+
if (!P.consumeIf(tok::period)) {
1726+
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, ".");
1727+
return true;
1728+
}
1729+
1730+
parameterIndices = AutoDiffParameterIndices::create(
1731+
SILMod.getASTContext(), P.Tok.getText());
1732+
if (!parameterIndices) {
1733+
P.diagnose(P.Tok, diag::malformed_autodiff_parameter_indices);
1734+
return true;
1735+
}
1736+
P.consumeToken();
1737+
1738+
autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get(
1739+
kind, differentiationOrder, parameterIndices,
1740+
SILMod.getASTContext());
1741+
16871742
break;
16881743
} else
16891744
break;
@@ -1697,7 +1752,8 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
16971752
} while (P.consumeIf(tok::period));
16981753

16991754
// Construct SILDeclRef.
1700-
Result = SILDeclRef(VD, Kind, /*isCurried=*/false, IsObjC);
1755+
// SWIFT_ENABLE_TENSORFLOW
1756+
Result = SILDeclRef(VD, Kind, /*isCurried=*/false, IsObjC, autoDiffFuncId);
17011757
if (uncurryLevel < Result.getParameterListCount() - 1)
17021758
Result = Result.asCurried();
17031759
return false;

lib/SIL/SILDeclRef.cpp

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,21 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
112112
}
113113

114114
SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind,
115-
bool isCurried, bool isForeign)
115+
// SWIFT_ENABLE_TENSORFLOW
116+
bool isCurried, bool isForeign,
117+
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId)
116118
: loc(vd), kind(kind),
117119
isCurried(isCurried), isForeign(isForeign),
118-
isDirectReference(0), defaultArgIndex(0)
120+
// SWIFT_ENABLE_TENSORFLOW
121+
isDirectReference(0), defaultArgIndex(0),
122+
autoDiffAssociatedFunctionIdentifier(autoDiffFuncId)
119123
{}
120124

121125
SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc,
122-
bool isCurried, bool asForeign)
123-
: isCurried(isCurried), isDirectReference(0), defaultArgIndex(0)
126+
bool isCurried, bool asForeign)
127+
// SWIFT_ENABLE_TENSORFLOW
128+
: isCurried(isCurried), isDirectReference(0), defaultArgIndex(0),
129+
autoDiffAssociatedFunctionIdentifier(nullptr)
124130
{
125131
if (auto *vd = baseLoc.dyn_cast<ValueDecl*>()) {
126132
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
@@ -601,6 +607,30 @@ static void mangleClangDecl(raw_ostream &buffer,
601607
}
602608

603609
std::string SILDeclRef::mangle(ManglingKind MKind) const {
610+
// SWIFT_ENABLE_TENSORFLOW
611+
if (autoDiffAssociatedFunctionIdentifier) {
612+
std::string originalMangled = asAutoDiffOriginalFunction().mangle(MKind);
613+
bool isMethod = cast<AbstractFunctionDecl>(getDecl())->getImplicitSelfDecl()
614+
? true : false;
615+
auto *functionTy =
616+
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
617+
auto silParameterIndices =
618+
autoDiffAssociatedFunctionIdentifier->getParameterIndices()->getLowered(
619+
functionTy, isMethod);
620+
SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices);
621+
std::string mangledKind;
622+
switch (autoDiffAssociatedFunctionIdentifier->getKind()) {
623+
case AutoDiffAssociatedFunctionKind::JVP:
624+
mangledKind = "jvp";
625+
break;
626+
case AutoDiffAssociatedFunctionKind::VJP:
627+
mangledKind = "vjp";
628+
break;
629+
}
630+
return "AD__" + originalMangled + "__" + mangledKind + "_" +
631+
indices.mangle();
632+
}
633+
604634
using namespace Mangle;
605635
ASTMangler mangler;
606636

lib/SIL/SILPrinter.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,23 @@ void SILDeclRef::print(raw_ostream &OS) const {
343343

344344
if (isDirectReference)
345345
OS << ((isDot || uncurryLevel != 0) ? '.' : '!') << "direct";
346+
347+
// SWIFT_ENABLE_TENSORFLOW
348+
if (autoDiffAssociatedFunctionIdentifier) {
349+
auto *autoDiffFuncId = autoDiffAssociatedFunctionIdentifier;
350+
OS << ((isDot || uncurryLevel != 0 || isForeign || isDirectReference)
351+
? '.' : '!');
352+
switch (autoDiffFuncId->getKind()) {
353+
case AutoDiffAssociatedFunctionKind::JVP:
354+
OS << "jvp.";
355+
break;
356+
case AutoDiffAssociatedFunctionKind::VJP:
357+
OS << "vjp.";
358+
break;
359+
}
360+
OS << autoDiffFuncId->getDifferentiationOrder() << "."
361+
<< autoDiffFuncId->getParameterIndices()->getString();
362+
}
346363
}
347364

348365
void SILDeclRef::dump() const {

0 commit comments

Comments
 (0)