Skip to content

add SILDeclRef modifier for autodiff functions #21224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,14 @@ struct SILAutoDiffIndices {
[&s](unsigned p) { s << p; }, [&s]{ s << ' '; });
s << "))";
}

std::string mangle() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you define a new AutoDiffMangler or similar, even if its still doing stuff like this? It will be easier to move to a real mangling in the future (which I encourage you to do).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of following the example in ASTMangler? I haven't read through that and understood how it works yet, but I'll do that and then implement something like it, in a later PR.

std::string result = "src_" + llvm::utostr(source) + "_wrt_";
interleave(parameters.set_bits(),
[&](unsigned idx) { result += llvm::utostr(idx); },
[&] { result += '_'; });
return result;
}
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ ERROR(expected_sil_value_ownership_kind,none,
"expected value ownership kind in SIL code", ())
ERROR(expected_sil_colon,none,
"expected ':' before %0", (StringRef))
// SWIFT_ENABLE_TENSORFLOW
ERROR(malformed_autodiff_parameter_indices,none,
"malformed autodiff parameter indices", ())

// SIL Values
ERROR(sil_value_redefinition,none,
Expand Down
68 changes: 58 additions & 10 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace swift {
enum class EffectsKind : uint8_t;
class AbstractFunctionDecl;
class AbstractClosureExpr;
// SWIFT_ENABLE_TENSORFLOW
class AutoDiffAssociatedFunctionIdentifier;
class ValueDecl;
class FuncDecl;
class ClosureExpr;
Expand Down Expand Up @@ -149,15 +151,25 @@ struct SILDeclRef {
/// The default argument index for a default argument getter.
unsigned defaultArgIndex : 10;

// SWIFT_ENABLE_TENSORFLOW
/// When this is non-null, it modifies the SILDeclRef to refer to the
/// corresponding autodiff associated function.
AutoDiffAssociatedFunctionIdentifier *autoDiffAssociatedFunctionIdentifier;

/// Produces a null SILDeclRef.
SILDeclRef() : loc(), kind(Kind::Func),
isCurried(0), isForeign(0), isDirectReference(0),
defaultArgIndex(0) {}
// SWIFT_ENABLE_TENSORFLOW
defaultArgIndex(0),
autoDiffAssociatedFunctionIdentifier(nullptr) {}

/// Produces a SILDeclRef of the given kind for the given decl.
explicit SILDeclRef(ValueDecl *decl, Kind kind,
bool isCurried = false,
bool isForeign = false);
// SWIFT_ENABLE_TENSORFLOW
bool isForeign = false,
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId =
nullptr);

/// Produces a SILDeclRef for the given ValueDecl or
/// AbstractClosureExpr:
Expand Down Expand Up @@ -284,7 +296,10 @@ struct SILDeclRef {
&& isCurried == rhs.isCurried
&& isForeign == rhs.isForeign
&& isDirectReference == rhs.isDirectReference
&& defaultArgIndex == rhs.defaultArgIndex;
// SWIFT_ENABLE_TENSORFLOW
&& defaultArgIndex == rhs.defaultArgIndex
&& autoDiffAssociatedFunctionIdentifier ==
rhs.autoDiffAssociatedFunctionIdentifier;
}
bool operator!=(SILDeclRef rhs) const {
return !(*this == rhs);
Expand All @@ -303,15 +318,19 @@ struct SILDeclRef {
bool willBeDirect = isDirectReference;
return SILDeclRef(loc.getOpaqueValue(), kind,
curried, willBeDirect, willBeForeign,
defaultArgIndex);
// SWIFT_ENABLE_TENSORFLOW
defaultArgIndex,
autoDiffAssociatedFunctionIdentifier);
}

/// Returns the foreign (or native) entry point corresponding to the same
/// decl.
SILDeclRef asForeign(bool foreign = true) const {
assert(!isCurried);
return SILDeclRef(loc.getOpaqueValue(), kind,
isCurried, isDirectReference, foreign, defaultArgIndex);
// SWIFT_ENABLE_TENSORFLOW
isCurried, isDirectReference, foreign, defaultArgIndex,
autoDiffAssociatedFunctionIdentifier);
}

SILDeclRef asDirectReference(bool direct = true) const {
Expand All @@ -322,6 +341,26 @@ struct SILDeclRef {
return r;
}

// SWIFT_ENABLE_TENSORFLOW
/// Returns the entry point for the corresponding autodiff associated
/// function.
SILDeclRef asAutoDiffAssociatedFunction(
AutoDiffAssociatedFunctionIdentifier *id) const {
assert(!autoDiffAssociatedFunctionIdentifier);
SILDeclRef r = *this;
r.autoDiffAssociatedFunctionIdentifier = id;
return r;
}

/// Returns the entry point for the original function corresponding to an
/// autodiff associated function.
SILDeclRef asAutoDiffOriginalFunction() const {
assert(autoDiffAssociatedFunctionIdentifier);
SILDeclRef r = *this;
r.autoDiffAssociatedFunctionIdentifier = nullptr;
return r;
}

/// True if the decl ref references a thunk from a natively foreign
/// declaration to Swift calling convention.
bool isForeignToNativeThunk() const;
Expand Down Expand Up @@ -392,12 +431,16 @@ struct SILDeclRef {
bool isCurried,
bool isDirectReference,
bool isForeign,
unsigned defaultArgIndex)
// SWIFT_ENABLE_TENSORFLOW
unsigned defaultArgIndex,
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId)
: loc(Loc::getFromOpaqueValue(opaqueLoc)),
kind(kind),
isCurried(isCurried),
isForeign(isForeign), isDirectReference(isDirectReference),
defaultArgIndex(defaultArgIndex)
// SWIFT_ENABLE_TENSORFLOW
defaultArgIndex(defaultArgIndex),
autoDiffAssociatedFunctionIdentifier(autoDiffFuncId)
{}

};
Expand All @@ -421,11 +464,13 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {

static SILDeclRef getEmptyKey() {
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func,
false, false, false, 0);
// SWIFT_ENABLE_TENSORFLOW
false, false, false, 0, nullptr);
}
static SILDeclRef getTombstoneKey() {
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func,
false, false, false, 0);
// SWIFT_ENABLE_TENSORFLOW
false, false, false, 0, nullptr);
}
static unsigned getHashValue(swift::SILDeclRef Val) {
unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue());
Expand All @@ -435,7 +480,10 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
: UnsignedInfo::getHashValue(Val.isCurried);
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
unsigned h5 = UnsignedInfo::getHashValue(Val.isDirectReference);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
// SWIFT_ENABLE_TENSORFLOW
unsigned h6 =
PointerInfo::getHashValue(Val.autoDiffAssociatedFunctionIdentifier);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11) ^ (h6 << 13);
}
static bool isEqual(swift::SILDeclRef const &LHS,
swift::SILDeclRef const &RHS) {
Expand Down
72 changes: 64 additions & 8 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1584,11 +1584,14 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {
.Default(None);
}

// SWIFT_ENABLE_TENSORFLOW
/// sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref?
/// sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-uncurry-level)?
/// ('.' sil-decl-lang)?
/// ('.' sil-decl-lang)? ('.' sil-decl-autodiff)?
/// sil-decl-subref ::= '!' sil-decl-uncurry-level ('.' sil-decl-lang)?
/// sil-decl-subref ::= '!' sil-decl-lang
/// ('.' sil-decl-autodiff)?
/// sil-decl-subref ::= '!' sil-decl-lang ('.' sil-decl-autodiff)?
/// sil-decl-subref ::= '!' sil-decl-autodiff
/// sil-decl-subref-part ::= 'getter'
/// sil-decl-subref-part ::= 'setter'
/// sil-decl-subref-part ::= 'allocator'
Expand All @@ -1598,6 +1601,12 @@ static Optional<AccessorKind> getAccessorKind(StringRef ident) {
/// sil-decl-subref-part ::= 'globalaccessor'
/// sil-decl-uncurry-level ::= [0-9]+
/// sil-decl-lang ::= 'foreign'
/// sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-order
/// '.' sil-decl-autodiff-indices
/// sil-decl-autodiff-kind ::= 'jvp'
/// sil-decl-autodiff-kind ::= 'vjp'
/// sil-decl-autodiff-order ::= [0-9]+
/// sil-decl-autodiff-indices ::= [FM][SU]+
bool SILParser::parseSILDeclRef(SILDeclRef &Result,
SmallVectorImpl<ValueDecl *> &values) {
ValueDecl *VD;
Expand All @@ -1608,6 +1617,8 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
SILDeclRef::Kind Kind = SILDeclRef::Kind::Func;
unsigned uncurryLevel = 0;
bool IsObjC = false;
// SWIFT_ENABLE_TENSORFLOW
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId = nullptr;

if (!P.consumeIf(tok::sil_exclamation)) {
// Construct SILDeclRef.
Expand All @@ -1619,10 +1630,13 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,

// Handle sil-constant-kind-and-uncurry-level.
// ParseState indicates the value we just handled.
// 1 means we just handled Kind, 2 means we just handled uncurryLevel.
// We accept func|getter|setter|...|foreign or an integer when ParseState is
// 0; accept foreign or an integer when ParseState is 1; accept foreign when
// ParseState is 2.
// SWIFT_ENABLE_TENSORFLOW
// 1 means we just handled Kind, 2 means we just handled uncurryLevel, 3 means
// we just handled foreign.
// We accept func|getter|setter|...|foreign, an autodiff identifier, or an
// integer when ParseState is 0; accept foreign, an autodiff identifier, or an
// integer when ParseState is 1; accept foreign or an autodiff identifier when
// ParseState is 2; accept an autodiff identifier when ParseState is 3.
unsigned ParseState = 0;
Identifier Id;
do {
Expand Down Expand Up @@ -1682,8 +1696,49 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
} else if (!ParseState && Id.str() == "propertyinit") {
Kind = SILDeclRef::Kind::StoredPropertyInitializer;
ParseState = 1;
} else if (Id.str() == "foreign") {
// SWIFT_ENABLE_TENSORFLOW
} else if (ParseState < 3 && Id.str() == "foreign") {
IsObjC = true;
// SWIFT_ENABLE_TENSORFLOW
ParseState = 3;
} else if (Id.str() == "jvp" || Id.str() == "vjp") {
AutoDiffAssociatedFunctionKind kind;
unsigned differentiationOrder;
AutoDiffParameterIndices *parameterIndices = nullptr;

if (Id.str() == "jvp")
kind = AutoDiffAssociatedFunctionKind::JVP;
else if (Id.str() == "vjp")
kind = AutoDiffAssociatedFunctionKind::VJP;
else
llvm_unreachable("Should only have JVP and VJP here");

if (!P.consumeIf(tok::period)) {
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, ".");
return true;
}

if (parseInteger(differentiationOrder,
diag::sil_const_expected_int_value))
return true;

if (!P.consumeIf(tok::period)) {
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, ".");
return true;
}

parameterIndices = AutoDiffParameterIndices::create(
SILMod.getASTContext(), P.Tok.getText());
if (!parameterIndices) {
P.diagnose(P.Tok, diag::malformed_autodiff_parameter_indices);
return true;
}
P.consumeToken();

autoDiffFuncId = AutoDiffAssociatedFunctionIdentifier::get(
kind, differentiationOrder, parameterIndices,
SILMod.getASTContext());

break;
} else
break;
Expand All @@ -1697,7 +1752,8 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
} while (P.consumeIf(tok::period));

// Construct SILDeclRef.
Result = SILDeclRef(VD, Kind, /*isCurried=*/false, IsObjC);
// SWIFT_ENABLE_TENSORFLOW
Result = SILDeclRef(VD, Kind, /*isCurried=*/false, IsObjC, autoDiffFuncId);
if (uncurryLevel < Result.getParameterListCount() - 1)
Result = Result.asCurried();
return false;
Expand Down
38 changes: 34 additions & 4 deletions lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,21 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
}

SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind,
bool isCurried, bool isForeign)
// SWIFT_ENABLE_TENSORFLOW
bool isCurried, bool isForeign,
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId)
: loc(vd), kind(kind),
isCurried(isCurried), isForeign(isForeign),
isDirectReference(0), defaultArgIndex(0)
// SWIFT_ENABLE_TENSORFLOW
isDirectReference(0), defaultArgIndex(0),
autoDiffAssociatedFunctionIdentifier(autoDiffFuncId)
{}

SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc,
bool isCurried, bool asForeign)
: isCurried(isCurried), isDirectReference(0), defaultArgIndex(0)
bool isCurried, bool asForeign)
// SWIFT_ENABLE_TENSORFLOW
: isCurried(isCurried), isDirectReference(0), defaultArgIndex(0),
autoDiffAssociatedFunctionIdentifier(nullptr)
{
if (auto *vd = baseLoc.dyn_cast<ValueDecl*>()) {
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
Expand Down Expand Up @@ -601,6 +607,30 @@ static void mangleClangDecl(raw_ostream &buffer,
}

std::string SILDeclRef::mangle(ManglingKind MKind) const {
// SWIFT_ENABLE_TENSORFLOW
if (autoDiffAssociatedFunctionIdentifier) {
std::string originalMangled = asAutoDiffOriginalFunction().mangle(MKind);
bool isMethod = cast<AbstractFunctionDecl>(getDecl())->getImplicitSelfDecl()
? true : false;
auto *functionTy =
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
auto silParameterIndices =
autoDiffAssociatedFunctionIdentifier->getParameterIndices()->getLowered(
functionTy, isMethod);
SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices);
std::string mangledKind;
switch (autoDiffAssociatedFunctionIdentifier->getKind()) {
case AutoDiffAssociatedFunctionKind::JVP:
mangledKind = "jvp";
break;
case AutoDiffAssociatedFunctionKind::VJP:
mangledKind = "vjp";
break;
}
return "AD__" + originalMangled + "__" + mangledKind + "_" +
indices.mangle();
}

using namespace Mangle;
ASTMangler mangler;

Expand Down
17 changes: 17 additions & 0 deletions lib/SIL/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,23 @@ void SILDeclRef::print(raw_ostream &OS) const {

if (isDirectReference)
OS << ((isDot || uncurryLevel != 0) ? '.' : '!') << "direct";

// SWIFT_ENABLE_TENSORFLOW
if (autoDiffAssociatedFunctionIdentifier) {
auto *autoDiffFuncId = autoDiffAssociatedFunctionIdentifier;
OS << ((isDot || uncurryLevel != 0 || isForeign || isDirectReference)
? '.' : '!');
switch (autoDiffFuncId->getKind()) {
case AutoDiffAssociatedFunctionKind::JVP:
OS << "jvp.";
break;
case AutoDiffAssociatedFunctionKind::VJP:
OS << "vjp.";
break;
}
OS << autoDiffFuncId->getDifferentiationOrder() << "."
<< autoDiffFuncId->getParameterIndices()->getString();
}
}

void SILDeclRef::dump() const {
Expand Down
Loading