Skip to content

Commit b435eef

Browse files
authored
AutoDiff associated functions in witness table, using SILDeclRef (#21241)
This deletes the AutoDiffAssociatedFunctionWitness from the witness table, and instead puts the autodiff associated functions in normal method witness table entries, using the autodiff SILDeclRef from #21224.
1 parent cb9ac39 commit b435eef

26 files changed

+408
-1020
lines changed

include/swift/AST/ASTMangler.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,6 @@ class ASTMangler : public Mangler {
120120
std::string mangleWitnessThunk(const ProtocolConformance *Conformance,
121121
const ValueDecl *Requirement);
122122

123-
// SWIFT_ENABLE_TENSORFLOW
124-
std::string mangleAutoDiffAssociatedFunctionWitnessThunk(
125-
const ProtocolConformance *Conformance, const ValueDecl *Requirement,
126-
const AutoDiffAssociatedFunctionIdentifier *AutoDiffFuncId);
127-
128123
std::string mangleClosureWitnessThunk(const ProtocolConformance *Conformance,
129124
const AbstractClosureExpr *Closure);
130125

include/swift/AST/DiagnosticsParse.def

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,11 +1528,6 @@ ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
15281528
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
15291529
"expects an assoiacted function kind attribute, e.g. '[jvp]'", ())
15301530

1531-
ERROR(malformed_autodiff_associated_function_kind,PointsToFirstBadToken,
1532-
"autodiff associated function kind must be 'jvp' or 'vjp'", ())
1533-
ERROR(malformed_autodiff_associated_function_indices,PointsToFirstBadToken,
1534-
"malformed autodiff associated function indices", ())
1535-
15361531
// SWIFT_ENABLE_TENSORFLOW
15371532
ERROR(pound_assert_expected,PointsToFirstBadToken,
15381533
"expected '%0' in #assert directive", (StringRef))

include/swift/SIL/SILDeclRef.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,13 @@ struct SILDeclRef {
361361
return r;
362362
}
363363

364+
/// Returns this `SILDeclRef` with the `loc` replaced with `decl`.
365+
SILDeclRef withDecl(ValueDecl *decl) const {
366+
SILDeclRef result = *this;
367+
result.loc = decl;
368+
return result;
369+
}
370+
364371
/// True if the decl ref references a thunk from a natively foreign
365372
/// declaration to Swift calling convention.
366373
bool isForeignToNativeThunk() const;

include/swift/SIL/SILWitnessTable.h

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
5252
/// This can be null in case dead function elimination has removed the method.
5353
SILFunction *Witness;
5454
};
55-
56-
// SWIFT_ENABLE_TENSORFLOW
57-
/// A witness table entry describing the witness for an autodiff associated
58-
/// function for a method.
59-
struct AutoDiffAssociatedFunctionWitness {
60-
/// The original method required.
61-
SILDeclRef RequirementOriginalMethod;
62-
/// The AutoDiffAssociatedFunctionIdentifier identifying the associated
63-
/// function.
64-
AutoDiffAssociatedFunctionIdentifier *RequirementIdentifier;
65-
/// The witness for the autodiff associated function.
66-
/// This can be null in case dead function elimination has removed the method.
67-
SILFunction *Witness;
68-
};
6955

7056
/// A witness table entry describing the witness for an associated type.
7157
struct AssociatedTypeWitness {
@@ -103,9 +89,7 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
10389
Method,
10490
AssociatedType,
10591
AssociatedTypeProtocol,
106-
BaseProtocol,
107-
// SWIFT_ENABLE_TENSORFLOW
108-
AutoDiffAssociatedFunction
92+
BaseProtocol
10993
};
11094

11195
/// A witness table entry.
@@ -116,8 +100,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
116100
AssociatedTypeWitness AssociatedType;
117101
AssociatedTypeProtocolWitness AssociatedTypeProtocol;
118102
BaseProtocolWitness BaseProtocol;
119-
// SWIFT_ENABLE_TENSORFLOW
120-
AutoDiffAssociatedFunctionWitness AutoDiffAssociatedFunction;
121103
};
122104

123105
public:
@@ -140,12 +122,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
140122
: Kind(WitnessKind::BaseProtocol),
141123
BaseProtocol(BaseProtocol)
142124
{}
143-
144-
// SWIFT_ENABLE_TENSORFLOW
145-
Entry(const AutoDiffAssociatedFunctionWitness &AutoDiffAssociatedFunction)
146-
: Kind(WitnessKind::AutoDiffAssociatedFunction),
147-
AutoDiffAssociatedFunction(AutoDiffAssociatedFunction)
148-
{}
149125

150126
WitnessKind getKind() const { return Kind; }
151127

@@ -168,11 +144,6 @@ class SILWitnessTable : public llvm::ilist_node<SILWitnessTable>,
168144
assert(Kind == WitnessKind::BaseProtocol);
169145
return BaseProtocol;
170146
}
171-
const AutoDiffAssociatedFunctionWitness
172-
&getAutoDiffAssociatedFunctionWitness() const {
173-
assert(Kind == WitnessKind::AutoDiffAssociatedFunction);
174-
return AutoDiffAssociatedFunction;
175-
}
176147

177148
void removeWitnessMethod() {
178149
assert(Kind == WitnessKind::Method);

include/swift/SIL/SILWitnessVisitor.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ namespace swift {
4242
/// - addMethod()
4343
/// - addConstructor()
4444
/// - addAssociatedType()
45-
/// SWIFT_ENABLE_TENSORFLOW
46-
/// - addAutoDiffAssociatedFunction()
4745

4846
template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
4947
T &asDerived() { return *static_cast<T*>(this); }
@@ -148,20 +146,18 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
148146
asDerived().addMethod(funcDeclRef);
149147

150148
if (auto *DA = func->getAttrs().getAttribute<DifferentiableAttr>()) {
151-
asDerived().addAutoDiffAssociatedFunction(
152-
funcDeclRef,
149+
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
153150
AutoDiffAssociatedFunctionIdentifier::get(
154151
AutoDiffAssociatedFunctionKind::JVP,
155152
/*differentiationOrder*/ 1,
156153
DA->getCheckedParameterIndices(),
157-
func->getASTContext()));
158-
asDerived().addAutoDiffAssociatedFunction(
159-
funcDeclRef,
154+
func->getASTContext())));
155+
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
160156
AutoDiffAssociatedFunctionIdentifier::get(
161157
AutoDiffAssociatedFunctionKind::VJP,
162158
/*differentiationOrder*/ 1,
163159
DA->getCheckedParameterIndices(),
164-
func->getASTContext()));
160+
func->getASTContext())));
165161
}
166162
}
167163

lib/AST/ASTMangler.cpp

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -233,43 +233,6 @@ std::string ASTMangler::mangleWitnessThunk(const ProtocolConformance *Conformanc
233233
return finalize();
234234
}
235235

236-
// SWIFT_ENABLE_TENSORFLOW
237-
std::string ASTMangler::mangleAutoDiffAssociatedFunctionWitnessThunk(
238-
const ProtocolConformance *Conformance, const ValueDecl *Requirement,
239-
const AutoDiffAssociatedFunctionIdentifier *id) {
240-
assert(id);
241-
242-
beginMangling();
243-
244-
// TODO: Proper mangling for autodiff associated function witness thunks.
245-
switch (id->getKind()) {
246-
case AutoDiffAssociatedFunctionKind::JVP:
247-
appendIdentifier("jvp");
248-
break;
249-
case AutoDiffAssociatedFunctionKind::VJP:
250-
appendIdentifier("vjp");
251-
break;
252-
}
253-
appendIdentifier(id->getParameterIndices()->getString() + " ");
254-
255-
// The rest of this function is copy-pasted from `mangleWitnessThunk`.
256-
257-
// Concrete witness thunks get a special mangling.
258-
if (Conformance)
259-
appendProtocolConformance(Conformance);
260-
261-
if (auto ctor = dyn_cast<ConstructorDecl>(Requirement)) {
262-
appendConstructorEntity(ctor, /*isAllocating=*/true);
263-
} else {
264-
assert(isa<FuncDecl>(Requirement) && "expected function");
265-
appendEntity(cast<FuncDecl>(Requirement));
266-
}
267-
268-
if (Conformance)
269-
appendOperator("TW");
270-
return finalize();
271-
}
272-
273236
std::string ASTMangler::mangleClosureWitnessThunk(
274237
const ProtocolConformance *Conformance,
275238
const AbstractClosureExpr *Closure) {

lib/IRGen/GenKeyPath.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,8 +960,8 @@ emitKeyPathComponent(IRGenModule &IGM,
960960
auto methodProto = cast<ProtocolDecl>(dc);
961961
auto &protoInfo = IGM.getProtocolInfo(methodProto,
962962
ProtocolInfoKind::Full);
963-
auto index = protoInfo.getFunctionIndex(
964-
cast<AbstractFunctionDecl>(declRef.getDecl()));
963+
// SWIFT_ENABLE_TENSORFLOW
964+
auto index = protoInfo.getFunctionIndex(declRef);
965965
idValue = llvm::ConstantInt::get(IGM.SizeTy, -index.getValue());
966966
idResolved = true;
967967
}

lib/IRGen/GenMeta.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -657,14 +657,6 @@ namespace {
657657
return { flags, defaultImpl };
658658
}
659659

660-
if (entry.isAutoDiffAssociatedFunction()) {
661-
assert(!Resilient && "TODO: Resilient autodiff associated funcs");
662-
auto flags = getMethodDescriptorFlags<Flags>(
663-
entry.getAutoDiffAssociatedFunctionOriginal());
664-
// TODO: Default witness.
665-
return { flags, nullptr };
666-
}
667-
668660
assert(entry.isFunction());
669661
SILDeclRef func(entry.getFunction());
670662

lib/IRGen/GenProto.cpp

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -776,23 +776,15 @@ namespace {
776776
}
777777

778778
void addMethod(SILDeclRef func) {
779-
auto decl = cast<AbstractFunctionDecl>(func.getDecl());
780-
Entries.push_back(WitnessTableEntry::forFunction(decl));
781-
}
782-
783-
// SWIFT_ENABLE_TENSORFLOW
784-
void addAutoDiffAssociatedFunction(
785-
SILDeclRef origFunc,
786-
AutoDiffAssociatedFunctionIdentifier *autoDiffFuncId) {
787-
auto decl = cast<AbstractFunctionDecl>(origFunc.getDecl());
788-
Entries.push_back(WitnessTableEntry::forAutoDiffAssociatedFunction(
789-
decl, autoDiffFuncId));
779+
// SWIFT_ENABLE_TENSORFLOW
780+
Entries.push_back(WitnessTableEntry::forFunction(func));
790781
}
791782

792783
void addPlaceholder(MissingMemberDecl *placeholder) {
793784
for (auto i : range(placeholder->getNumberOfVTableEntries())) {
794785
(void)i;
795-
Entries.push_back(WitnessTableEntry());
786+
// SWIFT_ENABLE_TENSORFLOW
787+
Entries.push_back(WitnessTableEntry::forPlaceholder());
796788
}
797789
}
798790

@@ -1327,8 +1319,8 @@ class AccessorConformanceInfo : public ConformanceInfo {
13271319
&& "sil witness table does not match protocol");
13281320
assert(entry.getMethodWitness().Requirement == requirement
13291321
&& "sil witness table does not match protocol");
1330-
auto piIndex =
1331-
PI.getFunctionIndex(cast<AbstractFunctionDecl>(requirement.getDecl()));
1322+
// SWIFT_ENABLE_TENSORFLOW
1323+
auto piIndex = PI.getFunctionIndex(requirement);
13321324
assert((size_t)piIndex.getValue() ==
13331325
Table.size() - WitnessTableFirstRequirementOffset &&
13341326
"offset doesn't match ProtocolInfo layout");
@@ -1347,46 +1339,6 @@ class AccessorConformanceInfo : public ConformanceInfo {
13471339
return;
13481340
}
13491341

1350-
// SWIFT_ENABLE_TENSORFLOW
1351-
void addAutoDiffAssociatedFunction(
1352-
SILDeclRef requirementOriginalMethod,
1353-
AutoDiffAssociatedFunctionIdentifier *requirementIdentifier) {
1354-
auto &entry = SILEntries.front();
1355-
SILEntries = SILEntries.slice(1);
1356-
1357-
// Resilient conformances get a resilient witness table.
1358-
if (ResilientConformance)
1359-
return;
1360-
1361-
#ifndef NDEBUG
1362-
assert(entry.getKind() == SILWitnessTable::AutoDiffAssociatedFunction
1363-
&& "sil witness table does not match protocol");
1364-
auto silWitness = entry.getAutoDiffAssociatedFunctionWitness();
1365-
assert(silWitness.RequirementOriginalMethod == requirementOriginalMethod
1366-
&& "sil witness table does not match protocol");
1367-
assert(silWitness.RequirementIdentifier == requirementIdentifier
1368-
&& "sil witness table does not match protocol");
1369-
auto piIndex = PI.getAutoDiffAssociatedFunctionIndex(
1370-
cast<AbstractFunctionDecl>(requirementOriginalMethod.getDecl()),
1371-
requirementIdentifier);
1372-
assert((size_t)piIndex.getValue() ==
1373-
Table.size() - WitnessTableFirstRequirementOffset &&
1374-
"offset doesn't match ProtocolInfo layout");
1375-
#endif
1376-
1377-
SILFunction *Func = entry.getAutoDiffAssociatedFunctionWitness().Witness;
1378-
llvm::Constant *witness = nullptr;
1379-
if (Func) {
1380-
witness = IGM.getAddrOfSILFunction(Func, NotForDefinition);
1381-
} else {
1382-
// The method is removed by dead method elimination.
1383-
// It should be never called. We add a pointer to an error function.
1384-
witness = IGM.getDeletedMethodErrorFn();
1385-
}
1386-
Table.addBitCast(witness, IGM.Int8PtrTy);
1387-
return;
1388-
}
1389-
13901342
void addPlaceholder(MissingMemberDecl *placeholder) {
13911343
llvm_unreachable("cannot emit a witness table with placeholders in it");
13921344
}
@@ -2296,8 +2248,6 @@ static bool isConstantWitnessTable(SILWitnessTable *wt) {
22962248
case SILWitnessTable::AssociatedTypeProtocol:
22972249
case SILWitnessTable::BaseProtocol:
22982250
case SILWitnessTable::Method:
2299-
// SWIFT_ENABLE_TENSORFLOW
2300-
case SILWitnessTable::AutoDiffAssociatedFunction:
23012251
continue;
23022252

23032253
case SILWitnessTable::AssociatedType:
@@ -3430,7 +3380,8 @@ irgen::emitWitnessMethodValue(IRGenFunction &IGF,
34303380

34313381
// Find the witness we're interested in.
34323382
auto &fnProtoInfo = IGF.IGM.getProtocolInfo(proto, ProtocolInfoKind::Full);
3433-
auto index = fnProtoInfo.getFunctionIndex(fn);
3383+
// SWIFT_ENABLE_TENSORFLOW
3384+
auto index = fnProtoInfo.getFunctionIndex(member);
34343385
llvm::Value *witnessFnPtr =
34353386
emitInvariantLoadOfOpaqueWitness(IGF, wtable,
34363387
index.forProtocolWitnessTable());

lib/IRGen/IRGenSIL.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,9 @@ void IRGenSILFunction::visitAutoDiffFunctionInst(AutoDiffFunctionInst *i) {
19851985
e.add(origExp.claimAll());
19861986
for (auto &assocFnOp : i->getAssociatedFunctions())
19871987
e.add(getLoweredExplosion(assocFnOp.get()).claimAll());
1988+
assert(1 + i->getNumAssociatedFunctions() ==
1989+
getTypeInfo(i->getType()).getSchema().size() &&
1990+
"the AD pass hasn't added associated functions to this instruction");
19881991
setLoweredExplosion(i, e);
19891992
}
19901993

0 commit comments

Comments
 (0)