Skip to content

[AutoDiff] NFC: formatting. #30573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {

public:
AutoDiffDerivativeFunctionKind getKind() const { return kind; }
IndexSubset *getParameterIndices() const {
return parameterIndices;
}
IndexSubset *getParameterIndices() const { return parameterIndices; }
GenericSignature getDerivativeGenericSignature() const {
return derivativeGenericSignature;
}
Expand Down
53 changes: 24 additions & 29 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,17 @@ struct SILDeclRef {
unsigned defaultArgIndex : 10;
/// The derivative function identifier.
AutoDiffDerivativeFunctionIdentifier *derivativeFunctionIdentifier = nullptr;

/// Produces a null SILDeclRef.
SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
derivativeFunctionIdentifier(nullptr) {}

SILDeclRef()
: loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
derivativeFunctionIdentifier(nullptr) {}

/// Produces a SILDeclRef of the given kind for the given decl.
explicit SILDeclRef(ValueDecl *decl, Kind kind,
bool isForeign = false,
AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr);
explicit SILDeclRef(
ValueDecl *decl, Kind kind, bool isForeign = false,
AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr);

/// Produces a SILDeclRef for the given ValueDecl or
/// AbstractClosureExpr:
/// - If 'loc' is a func or closure, this returns a Func SILDeclRef.
Expand Down Expand Up @@ -283,11 +284,10 @@ struct SILDeclRef {
}

bool operator==(SILDeclRef rhs) const {
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue()
&& kind == rhs.kind
&& isForeign == rhs.isForeign
&& defaultArgIndex == rhs.defaultArgIndex
&& derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue() &&
kind == rhs.kind && isForeign == rhs.isForeign &&
defaultArgIndex == rhs.defaultArgIndex &&
derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
}
bool operator!=(SILDeclRef rhs) const {
return !(*this == rhs);
Expand All @@ -301,8 +301,8 @@ struct SILDeclRef {
/// Returns the foreign (or native) entry point corresponding to the same
/// decl.
SILDeclRef asForeign(bool foreign = true) const {
return SILDeclRef(loc.getOpaqueValue(), kind,
foreign, defaultArgIndex, derivativeFunctionIdentifier);
return SILDeclRef(loc.getOpaqueValue(), kind, foreign, defaultArgIndex,
derivativeFunctionIdentifier);
}

/// Returns the entry point for the corresponding autodiff derivative
Expand Down Expand Up @@ -400,16 +400,12 @@ struct SILDeclRef {
private:
friend struct llvm::DenseMapInfo<swift::SILDeclRef>;
/// Produces a SILDeclRef from an opaque value.
explicit SILDeclRef(void *opaqueLoc,
Kind kind,
bool isForeign,
explicit SILDeclRef(void *opaqueLoc, Kind kind, bool isForeign,
unsigned defaultArgIndex,
AutoDiffDerivativeFunctionIdentifier *derivativeId)
: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
derivativeFunctionIdentifier(derivativeId)
{}

: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
derivativeFunctionIdentifier(derivativeId) {}
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, SILDeclRef C) {
Expand All @@ -430,12 +426,12 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
using UnsignedInfo = DenseMapInfo<unsigned>;

static SILDeclRef getEmptyKey() {
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func,
false, 0, nullptr);
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func, false, 0,
nullptr);
}
static SILDeclRef getTombstoneKey() {
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func,
false, 0, nullptr);
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func, false, 0,
nullptr);
}
static unsigned getHashValue(swift::SILDeclRef Val) {
unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue());
Expand All @@ -444,8 +440,7 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
? UnsignedInfo::getHashValue(Val.defaultArgIndex)
: 0;
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
unsigned h5 =
PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
unsigned h5 = PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
}
static bool isEqual(swift::SILDeclRef const &LHS,
Expand Down
10 changes: 5 additions & 5 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

using namespace swift;

AutoDiffDerivativeFunctionKind::
AutoDiffDerivativeFunctionKind(StringRef string) {
Optional<innerty> result =
llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP).Case("vjp", VJP);
AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind(
StringRef string) {
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP)
.Case("vjp", VJP);
assert(result && "Invalid string");
rawValue = *result;
}
Expand Down
4 changes: 2 additions & 2 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,8 +1456,8 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
return true;
}
// Parse parameter indices.
parameterIndices = IndexSubset::getFromString(
SILMod.getASTContext(), P.Tok.getText());
parameterIndices =
IndexSubset::getFromString(SILMod.getASTContext(), P.Tok.getText());
if (!parameterIndices) {
P.diagnose(P.Tok, diag::invalid_index_subset);
return true;
Expand Down
17 changes: 6 additions & 11 deletions lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,13 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
return false;
}

SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind,
bool isForeign,
SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind, bool isForeign,
AutoDiffDerivativeFunctionIdentifier *derivativeId)
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0),
derivativeFunctionIdentifier(derivativeId)
{}
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0),
derivativeFunctionIdentifier(derivativeId) {}

SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
: defaultArgIndex(0), derivativeFunctionIdentifier(nullptr)
{
SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
: defaultArgIndex(0), derivativeFunctionIdentifier(nullptr) {
if (auto *vd = baseLoc.dyn_cast<ValueDecl*>()) {
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
// Map FuncDecls directly to Func SILDeclRefs.
Expand Down Expand Up @@ -900,16 +897,14 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
return SILDeclRef();

// JVPs/VJPs are overridden only if the base declaration has a
// `@differentiable` with the same parameter indices.
// `@differentiable` attribute with the same parameter indices.
if (derivativeFunctionIdentifier) {
auto overriddenAttrs =
overridden.getDecl()->getAttrs().getAttributes<DifferentiableAttr>();
for (const auto *attr : overriddenAttrs) {
if (attr->getParameterIndices() !=
derivativeFunctionIdentifier->getParameterIndices())
continue;

// TODO(TF-1056): Do we need to check generic signature requirements?
auto *overriddenDerivativeId = overridden.derivativeFunctionIdentifier;
overridden.derivativeFunctionIdentifier =
AutoDiffDerivativeFunctionIdentifier::get(
Expand Down
4 changes: 2 additions & 2 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3115,8 +3115,8 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
auto *loweredIndices = autodiff::getLoweredParameterIndices(
derivativeId->getParameterIndices(), formalInterfaceType);
silFnType = origFnConstantInfo.SILFnType->getAutoDiffDerivativeFunctionType(
loweredIndices, /*resultIndex*/ 0, derivativeId->getKind(),
*this, LookUpConformanceInModule(&M));
loweredIndices, /*resultIndex*/ 0, derivativeId->getKind(), *this,
LookUpConformanceInModule(&M));
}

LLVM_DEBUG(llvm::dbgs() << "lowering type for constant ";
Expand Down