Skip to content

Commit aab622e

Browse files
authored
[AutoDiff upstream] Add derivative function SILDeclRefs. (#30564)
`@differentiable` attribute on protocol requirements and non-final class members will produce derivative function entries in witness tables and vtables. This patch adds an optional derivative function configuration (`AutoDiffDerivativeFunctionIdentifier`) to `SILDeclRef` to represent these derivative function entries. Derivative function configurations consist of: - A derivative function kind (JVP or VJP). - Differentiability parameter indices. Resolves TF-1209. Enables TF-1212: upstream derivative function entries in witness tables/vtables.
1 parent eb93cd6 commit aab622e

File tree

11 files changed

+268
-27
lines changed

11 files changed

+268
-27
lines changed

docs/SIL.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,8 +1088,9 @@ Declaration References
10881088
::
10891089

10901090
sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref?
1091-
sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)?
1091+
sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)? ('.' sil-decl-autodiff)?
10921092
sil-decl-subref ::= '!' sil-decl-lang
1093+
sil-decl-subref ::= '!' sil-decl-autodiff
10931094
sil-decl-subref-part ::= 'getter'
10941095
sil-decl-subref-part ::= 'setter'
10951096
sil-decl-subref-part ::= 'allocator'
@@ -1102,6 +1103,10 @@ Declaration References
11021103
sil-decl-subref-part ::= 'ivarinitializer'
11031104
sil-decl-subref-part ::= 'defaultarg' '.' [0-9]+
11041105
sil-decl-lang ::= 'foreign'
1106+
sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-indices
1107+
sil-decl-autodiff-kind ::= 'jvp'
1108+
sil-decl-autodiff-kind ::= 'vjp'
1109+
sil-decl-autodiff-indices ::= [SU]+
11051110

11061111
Some SIL instructions need to reference Swift declarations directly. These
11071112
references are introduced with the ``#`` sigil followed by the fully qualified

include/swift/AST/AutoDiff.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,42 @@ struct AutoDiffDerivativeFunctionKind {
7575
}
7676
};
7777

78+
/// A derivative function configuration, uniqued in `ASTContext`.
79+
/// Identifies a specific derivative function given an original function.
80+
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
81+
const AutoDiffDerivativeFunctionKind kind;
82+
IndexSubset *const parameterIndices;
83+
GenericSignature derivativeGenericSignature;
84+
85+
AutoDiffDerivativeFunctionIdentifier(
86+
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
87+
GenericSignature derivativeGenericSignature)
88+
: kind(kind), parameterIndices(parameterIndices),
89+
derivativeGenericSignature(derivativeGenericSignature) {}
90+
91+
public:
92+
AutoDiffDerivativeFunctionKind getKind() const { return kind; }
93+
IndexSubset *getParameterIndices() const {
94+
return parameterIndices;
95+
}
96+
GenericSignature getDerivativeGenericSignature() const {
97+
return derivativeGenericSignature;
98+
}
99+
100+
static AutoDiffDerivativeFunctionIdentifier *
101+
get(AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
102+
GenericSignature derivativeGenericSignature, ASTContext &C);
103+
104+
void Profile(llvm::FoldingSetNodeID &ID) {
105+
ID.AddInteger(kind);
106+
ID.AddPointer(parameterIndices);
107+
CanGenericSignature derivativeCanGenSig;
108+
if (derivativeGenericSignature)
109+
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
110+
ID.AddPointer(derivativeCanGenSig.getPointer());
111+
}
112+
};
113+
78114
/// The kind of a differentiability witness function.
79115
struct DifferentiabilityWitnessFunctionKind {
80116
enum innerty : uint8_t {

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,9 @@ ERROR(expected_sil_colon,none,
504504
"expected ':' before %0", (StringRef))
505505
ERROR(expected_sil_tuple_index,none,
506506
"expected tuple element index", ())
507+
ERROR(invalid_index_subset,none,
508+
"invalid index subset; expected '[SU]+' where 'S' represents set indices "
509+
"and 'U' represents unset indices", ())
507510

508511
// SIL Values
509512
ERROR(sil_value_redefinition,none,

include/swift/SIL/SILDeclRef.h

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ namespace swift {
3434
enum class EffectsKind : uint8_t;
3535
class AbstractFunctionDecl;
3636
class AbstractClosureExpr;
37+
class AutoDiffDerivativeFunctionIdentifier;
3738
class ValueDecl;
3839
class FuncDecl;
3940
class ClosureExpr;
@@ -147,13 +148,17 @@ struct SILDeclRef {
147148
unsigned isForeign : 1;
148149
/// The default argument index for a default argument getter.
149150
unsigned defaultArgIndex : 10;
151+
/// The derivative function identifier.
152+
AutoDiffDerivativeFunctionIdentifier *derivativeFunctionIdentifier = nullptr;
150153

151154
/// Produces a null SILDeclRef.
152-
SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0) {}
155+
SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
156+
derivativeFunctionIdentifier(nullptr) {}
153157

154158
/// Produces a SILDeclRef of the given kind for the given decl.
155159
explicit SILDeclRef(ValueDecl *decl, Kind kind,
156-
bool isForeign = false);
160+
bool isForeign = false,
161+
AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr);
157162

158163
/// Produces a SILDeclRef for the given ValueDecl or
159164
/// AbstractClosureExpr:
@@ -166,8 +171,7 @@ struct SILDeclRef {
166171
/// for the containing ClassDecl.
167172
/// - If 'loc' is a global VarDecl, this returns its GlobalAccessor
168173
/// SILDeclRef.
169-
explicit SILDeclRef(Loc loc,
170-
bool isForeign = false);
174+
explicit SILDeclRef(Loc loc, bool isForeign = false);
171175

172176
/// Produce a SIL constant for a default argument generator.
173177
static SILDeclRef getDefaultArgGenerator(Loc loc, unsigned defaultArgIndex);
@@ -282,7 +286,8 @@ struct SILDeclRef {
282286
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue()
283287
&& kind == rhs.kind
284288
&& isForeign == rhs.isForeign
285-
&& defaultArgIndex == rhs.defaultArgIndex;
289+
&& defaultArgIndex == rhs.defaultArgIndex
290+
&& derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
286291
}
287292
bool operator!=(SILDeclRef rhs) const {
288293
return !(*this == rhs);
@@ -297,7 +302,26 @@ struct SILDeclRef {
297302
/// decl.
298303
SILDeclRef asForeign(bool foreign = true) const {
299304
return SILDeclRef(loc.getOpaqueValue(), kind,
300-
foreign, defaultArgIndex);
305+
foreign, defaultArgIndex, derivativeFunctionIdentifier);
306+
}
307+
308+
/// Returns the entry point for the corresponding autodiff derivative
309+
/// function.
310+
SILDeclRef asAutoDiffDerivativeFunction(
311+
AutoDiffDerivativeFunctionIdentifier *derivativeId) const {
312+
assert(!derivativeFunctionIdentifier);
313+
SILDeclRef declRef = *this;
314+
declRef.derivativeFunctionIdentifier = derivativeId;
315+
return declRef;
316+
}
317+
318+
/// Returns the entry point for the original function corresponding to an
319+
/// autodiff derivative function.
320+
SILDeclRef asAutoDiffOriginalFunction() const {
321+
assert(derivativeFunctionIdentifier);
322+
SILDeclRef declRef = *this;
323+
declRef.derivativeFunctionIdentifier = nullptr;
324+
return declRef;
301325
}
302326

303327
/// True if the decl ref references a thunk from a natively foreign
@@ -372,9 +396,11 @@ struct SILDeclRef {
372396
explicit SILDeclRef(void *opaqueLoc,
373397
Kind kind,
374398
bool isForeign,
375-
unsigned defaultArgIndex)
399+
unsigned defaultArgIndex,
400+
AutoDiffDerivativeFunctionIdentifier *derivativeId)
376401
: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
377-
isForeign(isForeign), defaultArgIndex(defaultArgIndex)
402+
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
403+
derivativeFunctionIdentifier(derivativeId)
378404
{}
379405

380406
};
@@ -398,11 +424,11 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
398424

399425
static SILDeclRef getEmptyKey() {
400426
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func,
401-
false, 0);
427+
false, 0, nullptr);
402428
}
403429
static SILDeclRef getTombstoneKey() {
404430
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func,
405-
false, 0);
431+
false, 0, nullptr);
406432
}
407433
static unsigned getHashValue(swift::SILDeclRef Val) {
408434
unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue());
@@ -411,7 +437,9 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
411437
? UnsignedInfo::getHashValue(Val.defaultArgIndex)
412438
: 0;
413439
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
414-
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7);
440+
unsigned h5 =
441+
PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
442+
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
415443
}
416444
static bool isEqual(swift::SILDeclRef const &LHS,
417445
swift::SILDeclRef const &RHS) {

lib/AST/ASTContext.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,9 +420,9 @@ struct ASTContext::Implementation {
420420
llvm::FoldingSet<BuiltinVectorType> BuiltinVectorTypes;
421421
llvm::FoldingSet<DeclName::CompoundDeclName> CompoundNames;
422422
llvm::DenseMap<UUID, OpenedArchetypeType *> OpenedExistentialArchetypes;
423-
424-
/// For uniquifying `IndexSubset` allocations.
425423
llvm::FoldingSet<IndexSubset> IndexSubsets;
424+
llvm::FoldingSet<AutoDiffDerivativeFunctionIdentifier>
425+
AutoDiffDerivativeFunctionIdentifiers;
426426

427427
/// A cache of information about whether particular nominal types
428428
/// are representable in a foreign language.
@@ -4754,3 +4754,30 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
47544754
foldingSet.InsertNode(newNode, insertPos);
47554755
return newNode;
47564756
}
4757+
4758+
AutoDiffDerivativeFunctionIdentifier *AutoDiffDerivativeFunctionIdentifier::get(
4759+
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
4760+
GenericSignature derivativeGenericSignature, ASTContext &C) {
4761+
assert(parameterIndices);
4762+
auto &foldingSet = C.getImpl().AutoDiffDerivativeFunctionIdentifiers;
4763+
llvm::FoldingSetNodeID id;
4764+
id.AddInteger((unsigned)kind);
4765+
id.AddPointer(parameterIndices);
4766+
CanGenericSignature derivativeCanGenSig;
4767+
if (derivativeGenericSignature)
4768+
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
4769+
id.AddPointer(derivativeCanGenSig.getPointer());
4770+
4771+
void *insertPos;
4772+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
4773+
if (existing)
4774+
return existing;
4775+
4776+
void *mem = C.Allocate(sizeof(AutoDiffDerivativeFunctionIdentifier),
4777+
alignof(AutoDiffDerivativeFunctionIdentifier));
4778+
auto *newNode = ::new (mem) AutoDiffDerivativeFunctionIdentifier(
4779+
kind, parameterIndices, derivativeGenericSignature);
4780+
foldingSet.InsertNode(newNode, insertPos);
4781+
4782+
return newNode;
4783+
}

lib/AST/AutoDiff.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919

2020
using namespace swift;
2121

22+
AutoDiffDerivativeFunctionKind::
23+
AutoDiffDerivativeFunctionKind(StringRef string) {
24+
Optional<innerty> result =
25+
llvm::StringSwitch<Optional<innerty>>(string)
26+
.Case("jvp", JVP).Case("vjp", VJP);
27+
assert(result && "Invalid string");
28+
rawValue = *result;
29+
}
30+
2231
DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
2332
StringRef string) {
2433
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)

lib/ParseSIL/ParseSIL.cpp

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,7 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {
13411341

13421342
/// sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref?
13431343
/// sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)?
1344+
/// ('.' sil-decl-autodiff)?
13441345
/// sil-decl-subref ::= '!' sil-decl-lang
13451346
/// sil-decl-subref-part ::= 'getter'
13461347
/// sil-decl-subref-part ::= 'setter'
@@ -1350,27 +1351,33 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {
13501351
/// sil-decl-subref-part ::= 'destroyer'
13511352
/// sil-decl-subref-part ::= 'globalaccessor'
13521353
/// sil-decl-lang ::= 'foreign'
1354+
/// sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-indices
1355+
/// sil-decl-autodiff-kind ::= 'jvp'
1356+
/// sil-decl-autodiff-kind ::= 'vjp'
1357+
/// sil-decl-autodiff-indices ::= [SU]+
13531358
bool SILParser::parseSILDeclRef(SILDeclRef &Result,
13541359
SmallVectorImpl<ValueDecl *> &values) {
13551360
ValueDecl *VD;
13561361
if (parseSILDottedPath(VD, values))
13571362
return true;
13581363

1359-
// Initialize Kind and IsObjC.
1364+
// Initialize SILDeclRef components.
13601365
SILDeclRef::Kind Kind = SILDeclRef::Kind::Func;
13611366
bool IsObjC = false;
1367+
AutoDiffDerivativeFunctionIdentifier *DerivativeId = nullptr;
13621368

13631369
if (!P.consumeIf(tok::sil_exclamation)) {
13641370
// Construct SILDeclRef.
1365-
Result = SILDeclRef(VD, Kind, IsObjC);
1371+
Result = SILDeclRef(VD, Kind, IsObjC, DerivativeId);
13661372
return false;
13671373
}
13681374

1369-
// Handle sil-constant-kind-and-uncurry-level.
1370-
// ParseState indicates the value we just handled.
1371-
// 1 means we just handled Kind.
1372-
// We accept func|getter|setter|...|foreign when ParseState is 0;
1373-
// accept foreign when ParseState is 1.
1375+
// Handle SILDeclRef components. ParseState tracks the last parsed component.
1376+
//
1377+
// When ParseState is 0, accept kind (`func|getter|setter|...`) and set
1378+
// ParseState to 1.
1379+
//
1380+
// Always accept `foreign` and derivative function identifier.
13741381
unsigned ParseState = 0;
13751382
Identifier Id;
13761383
do {
@@ -1439,15 +1446,47 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
14391446
} else if (Id.str() == "foreign") {
14401447
IsObjC = true;
14411448
break;
1442-
} else
1449+
} else if (Id.str() == "jvp" || Id.str() == "vjp") {
1450+
IndexSubset *parameterIndices = nullptr;
1451+
GenericSignature derivativeGenSig;
1452+
// Parse derivative function kind.
1453+
AutoDiffDerivativeFunctionKind derivativeKind(Id.str());
1454+
if (!P.consumeIf(tok::period)) {
1455+
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, ".");
1456+
return true;
1457+
}
1458+
// Parse parameter indices.
1459+
parameterIndices = IndexSubset::getFromString(
1460+
SILMod.getASTContext(), P.Tok.getText());
1461+
if (!parameterIndices) {
1462+
P.diagnose(P.Tok, diag::invalid_index_subset);
1463+
return true;
1464+
}
1465+
P.consumeToken();
1466+
// Parse derivative generic signature (optional).
1467+
if (P.Tok.is(tok::oper_binary_unspaced) && P.Tok.getText() == ".<") {
1468+
P.consumeStartingCharacterOfCurrentToken(tok::period);
1469+
// Create a new scope to avoid type redefinition errors.
1470+
Scope genericsScope(&P, ScopeKind::Generics);
1471+
auto *genericParams = P.maybeParseGenericParams().getPtrOrNull();
1472+
assert(genericParams);
1473+
auto *derivativeGenEnv = handleSILGenericParams(genericParams, &P.SF);
1474+
derivativeGenSig = derivativeGenEnv->getGenericSignature();
1475+
}
1476+
DerivativeId = AutoDiffDerivativeFunctionIdentifier::get(
1477+
derivativeKind, parameterIndices, derivativeGenSig,
1478+
SILMod.getASTContext());
1479+
break;
1480+
} else {
14431481
break;
1482+
}
14441483
} else
14451484
break;
14461485

14471486
} while (P.consumeIf(tok::period));
14481487

14491488
// Construct SILDeclRef.
1450-
Result = SILDeclRef(VD, Kind, IsObjC);
1489+
Result = SILDeclRef(VD, Kind, IsObjC, DerivativeId);
14511490
return false;
14521491
}
14531492

lib/SIL/SILDeclRef.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,14 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
114114
}
115115

116116
SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind,
117-
bool isForeign)
118-
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0)
117+
bool isForeign,
118+
AutoDiffDerivativeFunctionIdentifier *derivativeId)
119+
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0),
120+
derivativeFunctionIdentifier(derivativeId)
119121
{}
120122

121123
SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
122-
: defaultArgIndex(0)
124+
: defaultArgIndex(0), derivativeFunctionIdentifier(nullptr)
123125
{
124126
if (auto *vd = baseLoc.dyn_cast<ValueDecl*>()) {
125127
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
@@ -845,7 +847,8 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
845847
SILDeclRef SILDeclRef::getOverriddenWitnessTableEntry() const {
846848
auto bestOverridden =
847849
getOverriddenWitnessTableEntry(cast<AbstractFunctionDecl>(getDecl()));
848-
return SILDeclRef(bestOverridden, kind);
850+
return SILDeclRef(bestOverridden, kind, isForeign,
851+
derivativeFunctionIdentifier);
849852
}
850853

851854
AbstractFunctionDecl *SILDeclRef::getOverriddenWitnessTableEntry(

0 commit comments

Comments
 (0)