Skip to content

[AutoDiff upstream] Add derivative function SILDeclRefs. #30564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1088,8 +1088,9 @@ Declaration References
::

sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref?
sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)?
sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)? ('.' sil-decl-autodiff)?
sil-decl-subref ::= '!' sil-decl-lang
sil-decl-subref ::= '!' sil-decl-autodiff
sil-decl-subref-part ::= 'getter'
sil-decl-subref-part ::= 'setter'
sil-decl-subref-part ::= 'allocator'
Expand All @@ -1102,6 +1103,10 @@ Declaration References
sil-decl-subref-part ::= 'ivarinitializer'
sil-decl-subref-part ::= 'defaultarg' '.' [0-9]+
sil-decl-lang ::= 'foreign'
sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-indices
sil-decl-autodiff-kind ::= 'jvp'
sil-decl-autodiff-kind ::= 'vjp'
sil-decl-autodiff-indices ::= [SU]+

Some SIL instructions need to reference Swift declarations directly. These
references are introduced with the ``#`` sigil followed by the fully qualified
Expand Down
36 changes: 36 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,42 @@ struct AutoDiffDerivativeFunctionKind {
}
};

/// A derivative function configuration, uniqued in `ASTContext`.
/// Identifies a specific derivative function given an original function.
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
const AutoDiffDerivativeFunctionKind kind;
IndexSubset *const parameterIndices;
GenericSignature derivativeGenericSignature;

AutoDiffDerivativeFunctionIdentifier(
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
GenericSignature derivativeGenericSignature)
: kind(kind), parameterIndices(parameterIndices),
derivativeGenericSignature(derivativeGenericSignature) {}

public:
AutoDiffDerivativeFunctionKind getKind() const { return kind; }
IndexSubset *getParameterIndices() const {
return parameterIndices;
}
GenericSignature getDerivativeGenericSignature() const {
return derivativeGenericSignature;
}

static AutoDiffDerivativeFunctionIdentifier *
get(AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
GenericSignature derivativeGenericSignature, ASTContext &C);

void Profile(llvm::FoldingSetNodeID &ID) {
ID.AddInteger(kind);
ID.AddPointer(parameterIndices);
CanGenericSignature derivativeCanGenSig;
if (derivativeGenericSignature)
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
Comment on lines +107 to +109
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be simplified after #29105:

Suggested change
CanGenericSignature derivativeCanGenSig;
if (derivativeGenericSignature)
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
auto derivativeCanGenSig = derivativeGenericSignature.getCanonicalSignature();

I'll land this in a follow-up.

ID.AddPointer(derivativeCanGenSig.getPointer());
}
};

/// The kind of a differentiability witness function.
struct DifferentiabilityWitnessFunctionKind {
enum innerty : uint8_t {
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,9 @@ ERROR(expected_sil_colon,none,
"expected ':' before %0", (StringRef))
ERROR(expected_sil_tuple_index,none,
"expected tuple element index", ())
ERROR(invalid_index_subset,none,
"invalid index subset; expected '[SU]+' where 'S' represents set indices "
"and 'U' represents unset indices", ())

// SIL Values
ERROR(sil_value_redefinition,none,
Expand Down
50 changes: 39 additions & 11 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace swift {
enum class EffectsKind : uint8_t;
class AbstractFunctionDecl;
class AbstractClosureExpr;
class AutoDiffDerivativeFunctionIdentifier;
class ValueDecl;
class FuncDecl;
class ClosureExpr;
Expand Down Expand Up @@ -147,13 +148,17 @@ struct SILDeclRef {
unsigned isForeign : 1;
/// The default argument index for a default argument getter.
unsigned defaultArgIndex : 10;
/// The derivative function identifier.
AutoDiffDerivativeFunctionIdentifier *derivativeFunctionIdentifier = nullptr;

/// Produces a null SILDeclRef.
SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0) {}
SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
derivativeFunctionIdentifier(nullptr) {}

/// Produces a SILDeclRef of the given kind for the given decl.
explicit SILDeclRef(ValueDecl *decl, Kind kind,
bool isForeign = false);
bool isForeign = false,
AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr);

/// Produces a SILDeclRef for the given ValueDecl or
/// AbstractClosureExpr:
Expand All @@ -166,8 +171,7 @@ struct SILDeclRef {
/// for the containing ClassDecl.
/// - If 'loc' is a global VarDecl, this returns its GlobalAccessor
/// SILDeclRef.
explicit SILDeclRef(Loc loc,
bool isForeign = false);
explicit SILDeclRef(Loc loc, bool isForeign = false);

/// Produce a SIL constant for a default argument generator.
static SILDeclRef getDefaultArgGenerator(Loc loc, unsigned defaultArgIndex);
Expand Down Expand Up @@ -282,7 +286,8 @@ struct SILDeclRef {
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue()
&& kind == rhs.kind
&& isForeign == rhs.isForeign
&& defaultArgIndex == rhs.defaultArgIndex;
&& defaultArgIndex == rhs.defaultArgIndex
&& derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
}
bool operator!=(SILDeclRef rhs) const {
return !(*this == rhs);
Expand All @@ -297,7 +302,26 @@ struct SILDeclRef {
/// decl.
SILDeclRef asForeign(bool foreign = true) const {
return SILDeclRef(loc.getOpaqueValue(), kind,
foreign, defaultArgIndex);
foreign, defaultArgIndex, derivativeFunctionIdentifier);
}

/// Returns the entry point for the corresponding autodiff derivative
/// function.
SILDeclRef asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier *derivativeId) const {
assert(!derivativeFunctionIdentifier);
SILDeclRef declRef = *this;
declRef.derivativeFunctionIdentifier = derivativeId;
return declRef;
}

/// Returns the entry point for the original function corresponding to an
/// autodiff derivative function.
SILDeclRef asAutoDiffOriginalFunction() const {
assert(derivativeFunctionIdentifier);
SILDeclRef declRef = *this;
declRef.derivativeFunctionIdentifier = nullptr;
return declRef;
}

/// True if the decl ref references a thunk from a natively foreign
Expand Down Expand Up @@ -372,9 +396,11 @@ struct SILDeclRef {
explicit SILDeclRef(void *opaqueLoc,
Kind kind,
bool isForeign,
unsigned defaultArgIndex)
unsigned defaultArgIndex,
AutoDiffDerivativeFunctionIdentifier *derivativeId)
: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
isForeign(isForeign), defaultArgIndex(defaultArgIndex)
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
derivativeFunctionIdentifier(derivativeId)
{}

};
Expand All @@ -398,11 +424,11 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {

static SILDeclRef getEmptyKey() {
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func,
false, 0);
false, 0, nullptr);
}
static SILDeclRef getTombstoneKey() {
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func,
false, 0);
false, 0, nullptr);
}
static unsigned getHashValue(swift::SILDeclRef Val) {
unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue());
Expand All @@ -411,7 +437,9 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
? UnsignedInfo::getHashValue(Val.defaultArgIndex)
: 0;
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7);
unsigned h5 =
PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
}
static bool isEqual(swift::SILDeclRef const &LHS,
swift::SILDeclRef const &RHS) {
Expand Down
31 changes: 29 additions & 2 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,9 @@ struct ASTContext::Implementation {
llvm::FoldingSet<BuiltinVectorType> BuiltinVectorTypes;
llvm::FoldingSet<DeclName::CompoundDeclName> CompoundNames;
llvm::DenseMap<UUID, OpenedArchetypeType *> OpenedExistentialArchetypes;

/// For uniquifying `IndexSubset` allocations.
llvm::FoldingSet<IndexSubset> IndexSubsets;
llvm::FoldingSet<AutoDiffDerivativeFunctionIdentifier>
AutoDiffDerivativeFunctionIdentifiers;

/// A cache of information about whether particular nominal types
/// are representable in a foreign language.
Expand Down Expand Up @@ -4754,3 +4754,30 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
foldingSet.InsertNode(newNode, insertPos);
return newNode;
}

AutoDiffDerivativeFunctionIdentifier *AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
GenericSignature derivativeGenericSignature, ASTContext &C) {
assert(parameterIndices);
auto &foldingSet = C.getImpl().AutoDiffDerivativeFunctionIdentifiers;
llvm::FoldingSetNodeID id;
id.AddInteger((unsigned)kind);
id.AddPointer(parameterIndices);
CanGenericSignature derivativeCanGenSig;
if (derivativeGenericSignature)
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
id.AddPointer(derivativeCanGenSig.getPointer());

void *insertPos;
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
if (existing)
return existing;

void *mem = C.Allocate(sizeof(AutoDiffDerivativeFunctionIdentifier),
alignof(AutoDiffDerivativeFunctionIdentifier));
auto *newNode = ::new (mem) AutoDiffDerivativeFunctionIdentifier(
kind, parameterIndices, derivativeGenericSignature);
foldingSet.InsertNode(newNode, insertPos);

return newNode;
}
9 changes: 9 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@

using namespace swift;

AutoDiffDerivativeFunctionKind::
AutoDiffDerivativeFunctionKind(StringRef string) {
Optional<innerty> result =
llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP).Case("vjp", VJP);
assert(result && "Invalid string");
rawValue = *result;
}

DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
StringRef string) {
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
Expand Down
57 changes: 48 additions & 9 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,7 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {

/// sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref?
/// sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)?
/// ('.' sil-decl-autodiff)?
/// sil-decl-subref ::= '!' sil-decl-lang
/// sil-decl-subref-part ::= 'getter'
/// sil-decl-subref-part ::= 'setter'
Expand All @@ -1350,27 +1351,33 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {
/// sil-decl-subref-part ::= 'destroyer'
/// sil-decl-subref-part ::= 'globalaccessor'
/// sil-decl-lang ::= 'foreign'
/// sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-indices
/// sil-decl-autodiff-kind ::= 'jvp'
/// sil-decl-autodiff-kind ::= 'vjp'
/// sil-decl-autodiff-indices ::= [SU]+
bool SILParser::parseSILDeclRef(SILDeclRef &Result,
SmallVectorImpl<ValueDecl *> &values) {
ValueDecl *VD;
if (parseSILDottedPath(VD, values))
return true;

// Initialize Kind and IsObjC.
// Initialize SILDeclRef components.
SILDeclRef::Kind Kind = SILDeclRef::Kind::Func;
bool IsObjC = false;
AutoDiffDerivativeFunctionIdentifier *DerivativeId = nullptr;

if (!P.consumeIf(tok::sil_exclamation)) {
// Construct SILDeclRef.
Result = SILDeclRef(VD, Kind, IsObjC);
Result = SILDeclRef(VD, Kind, IsObjC, DerivativeId);
return false;
}

// Handle sil-constant-kind-and-uncurry-level.
// ParseState indicates the value we just handled.
// 1 means we just handled Kind.
// We accept func|getter|setter|...|foreign when ParseState is 0;
// accept foreign when ParseState is 1.
// Handle SILDeclRef components. ParseState tracks the last parsed component.
//
// When ParseState is 0, accept kind (`func|getter|setter|...`) and set
// ParseState to 1.
//
// Always accept `foreign` and derivative function identifier.
unsigned ParseState = 0;
Identifier Id;
do {
Expand Down Expand Up @@ -1439,15 +1446,47 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
} else if (Id.str() == "foreign") {
IsObjC = true;
break;
} else
} else if (Id.str() == "jvp" || Id.str() == "vjp") {
IndexSubset *parameterIndices = nullptr;
GenericSignature derivativeGenSig;
// Parse derivative function kind.
AutoDiffDerivativeFunctionKind derivativeKind(Id.str());
if (!P.consumeIf(tok::period)) {
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, ".");
return true;
}
// Parse parameter indices.
parameterIndices = IndexSubset::getFromString(
SILMod.getASTContext(), P.Tok.getText());
if (!parameterIndices) {
P.diagnose(P.Tok, diag::invalid_index_subset);
return true;
}
P.consumeToken();
// Parse derivative generic signature (optional).
if (P.Tok.is(tok::oper_binary_unspaced) && P.Tok.getText() == ".<") {
P.consumeStartingCharacterOfCurrentToken(tok::period);
// Create a new scope to avoid type redefinition errors.
Scope genericsScope(&P, ScopeKind::Generics);
auto *genericParams = P.maybeParseGenericParams().getPtrOrNull();
assert(genericParams);
auto *derivativeGenEnv = handleSILGenericParams(genericParams, &P.SF);
derivativeGenSig = derivativeGenEnv->getGenericSignature();
}
DerivativeId = AutoDiffDerivativeFunctionIdentifier::get(
derivativeKind, parameterIndices, derivativeGenSig,
SILMod.getASTContext());
break;
} else {
break;
}
} else
break;

} while (P.consumeIf(tok::period));

// Construct SILDeclRef.
Result = SILDeclRef(VD, Kind, IsObjC);
Result = SILDeclRef(VD, Kind, IsObjC, DerivativeId);
return false;
}

Expand Down
11 changes: 7 additions & 4 deletions lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
}

SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind,
bool isForeign)
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0)
bool isForeign,
AutoDiffDerivativeFunctionIdentifier *derivativeId)
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0),
derivativeFunctionIdentifier(derivativeId)
{}

SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
: defaultArgIndex(0)
: defaultArgIndex(0), derivativeFunctionIdentifier(nullptr)
{
if (auto *vd = baseLoc.dyn_cast<ValueDecl*>()) {
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
Expand Down Expand Up @@ -845,7 +847,8 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
SILDeclRef SILDeclRef::getOverriddenWitnessTableEntry() const {
auto bestOverridden =
getOverriddenWitnessTableEntry(cast<AbstractFunctionDecl>(getDecl()));
return SILDeclRef(bestOverridden, kind);
return SILDeclRef(bestOverridden, kind, isForeign,
derivativeFunctionIdentifier);
}

AbstractFunctionDecl *SILDeclRef::getOverriddenWitnessTableEntry(
Expand Down
Loading