Skip to content

AutoDiff associated functions in witness table, using SILDeclRef #21241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,6 @@ class ASTMangler : public Mangler {
std::string mangleWitnessThunk(const ProtocolConformance *Conformance,
const ValueDecl *Requirement);

// SWIFT_ENABLE_TENSORFLOW
std::string mangleAutoDiffAssociatedFunctionWitnessThunk(
const ProtocolConformance *Conformance, const ValueDecl *Requirement,
const AutoDiffAssociatedFunctionIdentifier *AutoDiffFuncId);

std::string mangleClosureWitnessThunk(const ProtocolConformance *Conformance,
const AbstractClosureExpr *Closure);

Expand Down
5 changes: 0 additions & 5 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1528,11 +1528,6 @@ ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
"expects an assoiacted function kind attribute, e.g. '[jvp]'", ())

ERROR(malformed_autodiff_associated_function_kind,PointsToFirstBadToken,
"autodiff associated function kind must be 'jvp' or 'vjp'", ())
ERROR(malformed_autodiff_associated_function_indices,PointsToFirstBadToken,
"malformed autodiff associated function indices", ())

// SWIFT_ENABLE_TENSORFLOW
ERROR(pound_assert_expected,PointsToFirstBadToken,
"expected '%0' in #assert directive", (StringRef))
Expand Down
7 changes: 7 additions & 0 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ struct SILDeclRef {
return r;
}

/// Returns this `SILDeclRef` with the `loc` replaced with `decl`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be SWIFT_ENABLE_TENSORFLOW?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pieces above it are tensorflow too, and they have a SWIFT_ENABLE_TENSORFLOW.

SILDeclRef withDecl(ValueDecl *decl) const {
SILDeclRef result = *this;
result.loc = decl;
return result;
}

/// True if the decl ref references a thunk from a natively foreign
/// declaration to Swift calling convention.
bool isForeignToNativeThunk() const;
Expand Down
31 changes: 1 addition & 30 deletions include/swift/SIL/SILWitnessTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
/// This can be null in case dead function elimination has removed the method.
SILFunction *Witness;
};

// SWIFT_ENABLE_TENSORFLOW
/// A witness table entry describing the witness for an autodiff associated
/// function for a method.
struct AutoDiffAssociatedFunctionWitness {
/// The original method required.
SILDeclRef RequirementOriginalMethod;
/// The AutoDiffAssociatedFunctionIdentifier identifying the associated
/// function.
AutoDiffAssociatedFunctionIdentifier *RequirementIdentifier;
/// The witness for the autodiff associated function.
/// This can be null in case dead function elimination has removed the method.
SILFunction *Witness;
};

/// A witness table entry describing the witness for an associated type.
struct AssociatedTypeWitness {
Expand Down Expand Up @@ -103,9 +89,7 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
Method,
AssociatedType,
AssociatedTypeProtocol,
BaseProtocol,
// SWIFT_ENABLE_TENSORFLOW
AutoDiffAssociatedFunction
BaseProtocol
};

/// A witness table entry.
Expand All @@ -116,8 +100,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
AssociatedTypeWitness AssociatedType;
AssociatedTypeProtocolWitness AssociatedTypeProtocol;
BaseProtocolWitness BaseProtocol;
// SWIFT_ENABLE_TENSORFLOW
AutoDiffAssociatedFunctionWitness AutoDiffAssociatedFunction;
};

public:
Expand All @@ -140,12 +122,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
: Kind(WitnessKind::BaseProtocol),
BaseProtocol(BaseProtocol)
{}

// SWIFT_ENABLE_TENSORFLOW
Entry(const AutoDiffAssociatedFunctionWitness &AutoDiffAssociatedFunction)
: Kind(WitnessKind::AutoDiffAssociatedFunction),
AutoDiffAssociatedFunction(AutoDiffAssociatedFunction)
{}

WitnessKind getKind() const { return Kind; }

Expand All @@ -168,11 +144,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
assert(Kind == WitnessKind::BaseProtocol);
return BaseProtocol;
}
const AutoDiffAssociatedFunctionWitness
&getAutoDiffAssociatedFunctionWitness() const {
assert(Kind == WitnessKind::AutoDiffAssociatedFunction);
return AutoDiffAssociatedFunction;
}

void removeWitnessMethod() {
assert(Kind == WitnessKind::Method);
Expand Down
12 changes: 4 additions & 8 deletions include/swift/SIL/SILWitnessVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ namespace swift {
/// - addMethod()
/// - addConstructor()
/// - addAssociatedType()
/// SWIFT_ENABLE_TENSORFLOW
/// - addAutoDiffAssociatedFunction()

template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
T &asDerived() { return *static_cast<T*>(this); }
Expand Down Expand Up @@ -148,20 +146,18 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
asDerived().addMethod(funcDeclRef);

if (auto *DA = func->getAttrs().getAttribute<DifferentiableAttr>()) {
asDerived().addAutoDiffAssociatedFunction(
funcDeclRef,
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
AutoDiffAssociatedFunctionIdentifier::get(
AutoDiffAssociatedFunctionKind::JVP,
/*differentiationOrder*/ 1,
DA->getCheckedParameterIndices(),
func->getASTContext()));
asDerived().addAutoDiffAssociatedFunction(
funcDeclRef,
func->getASTContext())));
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
AutoDiffAssociatedFunctionIdentifier::get(
AutoDiffAssociatedFunctionKind::VJP,
/*differentiationOrder*/ 1,
DA->getCheckedParameterIndices(),
func->getASTContext()));
func->getASTContext())));
}
}

Expand Down
37 changes: 0 additions & 37 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,43 +233,6 @@ std::string ASTMangler::mangleWitnessThunk(const ProtocolConformance *Conformanc
return finalize();
}

// SWIFT_ENABLE_TENSORFLOW
std::string ASTMangler::mangleAutoDiffAssociatedFunctionWitnessThunk(
const ProtocolConformance *Conformance, const ValueDecl *Requirement,
const AutoDiffAssociatedFunctionIdentifier *id) {
assert(id);

beginMangling();

// TODO: Proper mangling for autodiff associated function witness thunks.
switch (id->getKind()) {
case AutoDiffAssociatedFunctionKind::JVP:
appendIdentifier("jvp");
break;
case AutoDiffAssociatedFunctionKind::VJP:
appendIdentifier("vjp");
break;
}
appendIdentifier(id->getParameterIndices()->getString() + " ");

// The rest of this function is copy-pasted from `mangleWitnessThunk`.

// Concrete witness thunks get a special mangling.
if (Conformance)
appendProtocolConformance(Conformance);

if (auto ctor = dyn_cast<ConstructorDecl>(Requirement)) {
appendConstructorEntity(ctor, /*isAllocating=*/true);
} else {
assert(isa<FuncDecl>(Requirement) && "expected function");
appendEntity(cast<FuncDecl>(Requirement));
}

if (Conformance)
appendOperator("TW");
return finalize();
}

std::string ASTMangler::mangleClosureWitnessThunk(
const ProtocolConformance *Conformance,
const AbstractClosureExpr *Closure) {
Expand Down
4 changes: 2 additions & 2 deletions lib/IRGen/GenKeyPath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,8 +960,8 @@ emitKeyPathComponent(IRGenModule &IGM,
auto methodProto = cast<ProtocolDecl>(dc);
auto &protoInfo = IGM.getProtocolInfo(methodProto,
ProtocolInfoKind::Full);
auto index = protoInfo.getFunctionIndex(
cast<AbstractFunctionDecl>(declRef.getDecl()));
// SWIFT_ENABLE_TENSORFLOW
auto index = protoInfo.getFunctionIndex(declRef);
idValue = llvm::ConstantInt::get(IGM.SizeTy, -index.getValue());
idResolved = true;
}
Expand Down
8 changes: 0 additions & 8 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,14 +657,6 @@ namespace {
return { flags, defaultImpl };
}

if (entry.isAutoDiffAssociatedFunction()) {
assert(!Resilient && "TODO: Resilient autodiff associated funcs");
auto flags = getMethodDescriptorFlags<Flags>(
entry.getAutoDiffAssociatedFunctionOriginal());
// TODO: Default witness.
return { flags, nullptr };
}

assert(entry.isFunction());
SILDeclRef func(entry.getFunction());

Expand Down
65 changes: 8 additions & 57 deletions lib/IRGen/GenProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,23 +776,15 @@ namespace {
}

void addMethod(SILDeclRef func) {
auto decl = cast<AbstractFunctionDecl>(func.getDecl());
Entries.push_back(WitnessTableEntry::forFunction(decl));
}

// SWIFT_ENABLE_TENSORFLOW
void addAutoDiffAssociatedFunction(
SILDeclRef origFunc,
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId) {
auto decl = cast<AbstractFunctionDecl>(origFunc.getDecl());
Entries.push_back(WitnessTableEntry::forAutoDiffAssociatedFunction(
decl, autoDiffFuncId));
// SWIFT_ENABLE_TENSORFLOW
Entries.push_back(WitnessTableEntry::forFunction(func));
}

void addPlaceholder(MissingMemberDecl *placeholder) {
for (auto i : range(placeholder->getNumberOfVTableEntries())) {
(void)i;
Entries.push_back(WitnessTableEntry());
// SWIFT_ENABLE_TENSORFLOW
Entries.push_back(WitnessTableEntry::forPlaceholder());
}
}

Expand Down Expand Up @@ -1327,8 +1319,8 @@ class AccessorConformanceInfo : public ConformanceInfo {
&& "sil witness table does not match protocol");
assert(entry.getMethodWitness().Requirement == requirement
&& "sil witness table does not match protocol");
auto piIndex =
PI.getFunctionIndex(cast<AbstractFunctionDecl>(requirement.getDecl()));
// SWIFT_ENABLE_TENSORFLOW
auto piIndex = PI.getFunctionIndex(requirement);
assert((size_t)piIndex.getValue() ==
Table.size() - WitnessTableFirstRequirementOffset &&
"offset doesn't match ProtocolInfo layout");
Expand All @@ -1347,46 +1339,6 @@ class AccessorConformanceInfo : public ConformanceInfo {
return;
}

// SWIFT_ENABLE_TENSORFLOW
void addAutoDiffAssociatedFunction(
SILDeclRef requirementOriginalMethod,
AutoDiffAssociatedFunctionIdentifier *requirementIdentifier) {
auto &entry = SILEntries.front();
SILEntries = SILEntries.slice(1);

// Resilient conformances get a resilient witness table.
if (ResilientConformance)
return;

#ifndef NDEBUG
assert(entry.getKind() == SILWitnessTable::AutoDiffAssociatedFunction
&& "sil witness table does not match protocol");
auto silWitness = entry.getAutoDiffAssociatedFunctionWitness();
assert(silWitness.RequirementOriginalMethod == requirementOriginalMethod
&& "sil witness table does not match protocol");
assert(silWitness.RequirementIdentifier == requirementIdentifier
&& "sil witness table does not match protocol");
auto piIndex = PI.getAutoDiffAssociatedFunctionIndex(
cast<AbstractFunctionDecl>(requirementOriginalMethod.getDecl()),
requirementIdentifier);
assert((size_t)piIndex.getValue() ==
Table.size() - WitnessTableFirstRequirementOffset &&
"offset doesn't match ProtocolInfo layout");
#endif

SILFunction *Func = entry.getAutoDiffAssociatedFunctionWitness().Witness;
llvm::Constant *witness = nullptr;
if (Func) {
witness = IGM.getAddrOfSILFunction(Func, NotForDefinition);
} else {
// The method is removed by dead method elimination.
// It should be never called. We add a pointer to an error function.
witness = IGM.getDeletedMethodErrorFn();
}
Table.addBitCast(witness, IGM.Int8PtrTy);
return;
}

void addPlaceholder(MissingMemberDecl *placeholder) {
llvm_unreachable("cannot emit a witness table with placeholders in it");
}
Expand Down Expand Up @@ -2296,8 +2248,6 @@ static bool isConstantWitnessTable(SILWitnessTable *wt) {
case SILWitnessTable::AssociatedTypeProtocol:
case SILWitnessTable::BaseProtocol:
case SILWitnessTable::Method:
// SWIFT_ENABLE_TENSORFLOW
case SILWitnessTable::AutoDiffAssociatedFunction:
continue;

case SILWitnessTable::AssociatedType:
Expand Down Expand Up @@ -3430,7 +3380,8 @@ irgen::emitWitnessMethodValue(IRGenFunction &IGF,

// Find the witness we're interested in.
auto &fnProtoInfo = IGF.IGM.getProtocolInfo(proto, ProtocolInfoKind::Full);
auto index = fnProtoInfo.getFunctionIndex(fn);
// SWIFT_ENABLE_TENSORFLOW
auto index = fnProtoInfo.getFunctionIndex(member);
llvm::Value *witnessFnPtr =
emitInvariantLoadOfOpaqueWitness(IGF, wtable,
index.forProtocolWitnessTable());
Expand Down
3 changes: 3 additions & 0 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,9 @@ void IRGenSILFunction::visitAutoDiffFunctionInst(AutoDiffFunctionInst *i) {
e.add(origExp.claimAll());
for (auto &assocFnOp : i->getAssociatedFunctions())
e.add(getLoweredExplosion(assocFnOp.get()).claimAll());
assert(1 + i->getNumAssociatedFunctions() ==
getTypeInfo(i->getType()).getSchema().size() &&
"the AD pass hasn't added associated functions to this instruction");
setLoweredExplosion(i, e);
}

Expand Down
Loading