Skip to content

Commit 419eea2

Browse files
committed
Revamp serialization to enable lookup by key.
Use mangling to support string key lookup.
1 parent a2ae0f2 commit 419eea2

File tree

12 files changed

+205
-111
lines changed

12 files changed

+205
-111
lines changed

include/swift/AST/ASTMangler.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,23 +155,31 @@ class ASTMangler : public Mangler {
155155
ModuleDecl *Module);
156156

157157
// SWIFT_ENABLE_TENSORFLOW
158-
// Mangle the derivative function (JVP/VJP) with the given:
159-
// - Mangled original function name.
160-
// - Derivative function kind.
161-
// - Parameter/result indices.
158+
/// Mangle the derivative function (JVP/VJP) with the given:
159+
/// - Mangled original function name.
160+
/// - Derivative function kind.
161+
/// - Parameter/result indices.
162162
std::string mangleAutoDiffDerivativeFunctionHelper(
163163
StringRef name, AutoDiffDerivativeFunctionKind kind,
164164
const SILAutoDiffIndices &indices);
165165

166-
// SWIFT_ENABLE_TENSORFLOW
167-
// Mangle the autodiff linear map (differential/pullback) with the given:
168-
// - Mangled original function name.
169-
// - Linear map kind.
170-
// - Parameter/result indices.
166+
/// Mangle the autodiff linear map (differential/pullback) with the given:
167+
/// - Mangled original function name.
168+
/// - Linear map kind.
169+
/// - Parameter/result indices.
171170
std::string mangleAutoDiffLinearMapHelper(
172171
StringRef name, AutoDiffLinearMapKind kind,
173172
const SILAutoDiffIndices &indices);
174173

174+
/// Mangle a SIL differentiability witness key.
175+
/// - Mangled original function name.
176+
/// - Parameter indices.
177+
/// - Result indices.
178+
/// - Derivative generic signature (optional).
179+
std::string mangleSILDifferentiabilityWitnessKey(
180+
SILDifferentiabilityWitnessKey key);
181+
// SWIFT_ENABLE_TENSORFLOW END
182+
175183
std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
176184
GenericSignature *signature,
177185
CanType baseType,

include/swift/AST/AutoDiff.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,14 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
481481
}
482482
};
483483

484+
/// The key type used for uniquing `SILDifferentiabilityWitness` in
485+
/// `SILModule`: original function name, parameter indices, result indices, and
486+
/// derivative generic signature.
487+
// TODO: Unify with `AutoDiffDerivativeFunctionIdentifier`.
488+
using SILDifferentiabilityWitnessKey =
489+
std::tuple<StringRef, AutoDiffIndexSubset *,
490+
AutoDiffIndexSubset *, GenericSignature *>;
491+
484492
/// Automatic differentiation utility namespace.
485493
namespace autodiff {
486494
/// Appends the subset's parameter's types to `result`, in the order in

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,7 @@ class SILDifferentiabilityWitness
7575
serialized(isSerialized) {}
7676

7777
public:
78-
/// The key type, used for uniquing `SILDifferentiabilityWitness` in
79-
/// `SILModule`, original function, parameter indices, result indices, and
80-
/// derivative generic signature.
81-
using Key = std::tuple<const SILFunction *, AutoDiffIndexSubset *,
82-
AutoDiffIndexSubset *, GenericSignature *>;
83-
Key getKey() {
84-
return std::make_tuple(originalFunction, parameterIndices, resultIndices,
85-
derivativeGenericSignature);
86-
}
87-
78+
SILDifferentiabilityWitnessKey getKey() const;
8879
SILModule &getModule() const { return module; }
8980
SILLinkage getLinkage() const { return linkage; }
9081
SILFunction *getOriginalFunction() const { return originalFunction; }

include/swift/SIL/SILModule.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,8 @@ class SILModule {
207207
/// Lookup table for SIL differentiability witnesses from original functions.
208208
/// Indexed by key type: original function, parameter indices, result indices,
209209
/// and derivative generic signature.
210-
llvm::DenseMap<SILDifferentiabilityWitness::Key,
211-
SILDifferentiabilityWitness *>
212-
DifferentiabilityWitnessMap;
210+
llvm::DenseMap<SILDifferentiabilityWitnessKey, SILDifferentiabilityWitness *>
211+
DifferentiabilityWitnessMap;
213212

214213
/// The list of SILDifferentiabilityWitnesses in the module.
215214
DifferentiabilityWitnessListType differentiabilityWitnesses;

include/swift/Serialization/SerializedSILLoader.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#ifndef SWIFT_SERIALIZATION_SILLOADER_H
1414
#define SWIFT_SERIALIZATION_SILLOADER_H
1515

16+
// SWIFT_ENABLE_TENSORFLOW
17+
#include "swift/AST/AutoDiff.h"
18+
// SWIFT_ENABLE_TENSORFLOW END
1619
#include "swift/AST/Decl.h"
1720
#include "swift/AST/Identifier.h"
1821
#include "swift/SIL/Notifications.h"
@@ -32,6 +35,9 @@ class SILModule;
3235
class SILVTable;
3336
class SILWitnessTable;
3437
class SILDefaultWitnessTable;
38+
// SWIFT_ENABLE_TENSORFLOW
39+
class SILDifferentiabilityWitness;
40+
// SWIFT_ENABLE_TENSORFLOW END
3541

3642
/// Maintains a list of SILDeserializer, one for each serialized modules
3743
/// in ASTContext. It provides lookupSILFunction that will perform lookup
@@ -64,6 +70,10 @@ class SerializedSILLoader {
6470
SILVTable *lookupVTable(const ClassDecl *C);
6571
SILWitnessTable *lookupWitnessTable(SILWitnessTable *C);
6672
SILDefaultWitnessTable *lookupDefaultWitnessTable(SILDefaultWitnessTable *C);
73+
// SWIFT_ENABLE_TENSORFLOW
74+
SILDifferentiabilityWitness *
75+
lookupDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);
76+
// SWIFT_ENABLE_TENSORFLOW END
6777

6878
/// Invalidate the cached entries for deserialized SILFunctions.
6979
void invalidateCaches();

lib/AST/ASTMangler.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,11 +379,12 @@ std::string ASTMangler::mangleReabstractionThunkHelper(
379379
return finalize();
380380
}
381381

382+
// SWIFT_ENABLE_TENSORFLOW
382383
std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
383384
StringRef name, AutoDiffDerivativeFunctionKind kind,
384385
const SILAutoDiffIndices &indices) {
385386
// TODO(TF-20): Make the mangling scheme robust.
386-
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
387+
// TODO(TF-680): Mangle derivative generic signature as well.
387388
beginManglingWithoutPrefix();
388389

389390
Buffer << "AD__" << name << '_';
@@ -406,7 +407,7 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper(
406407
StringRef name, AutoDiffLinearMapKind kind,
407408
const SILAutoDiffIndices &indices) {
408409
// TODO(TF-20): Make the mangling scheme robust.
409-
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
410+
// TODO(TF-680): Mangle derivative generic signature as well.
410411
beginManglingWithoutPrefix();
411412

412413
Buffer << "AD__" << name << '_';
@@ -425,6 +426,29 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper(
425426
return result;
426427
}
427428

429+
std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
430+
SILDifferentiabilityWitnessKey key) {
431+
// TODO(TF-20): Make the mangling scheme robust.
432+
// TODO(TF-680): Mangle derivative generic signature as well.
433+
beginManglingWithoutPrefix();
434+
435+
auto originalName = std::get<0>(key);
436+
auto *parameterIndices = std::get<1>(key);
437+
auto *resultIndices = std::get<2>(key);
438+
auto *derivativeGenericSignature = std::get<3>(key);
439+
440+
Buffer << "AD__" << originalName << '_';
441+
Buffer << "P" << parameterIndices->getString();
442+
Buffer << "R" << resultIndices->getString();
443+
if (derivativeGenericSignature)
444+
appendGenericSignature(derivativeGenericSignature);
445+
446+
auto result = Storage.str().str();
447+
Storage.clear();
448+
return result;
449+
}
450+
// SWIFT_ENABLE_TENSORFLOW END
451+
428452
std::string ASTMangler::mangleTypeForDebugger(Type Ty, const DeclContext *DC) {
429453
PrettyStackTraceType prettyStackTrace(Ty->getASTContext(),
430454
"mangling type for debugger", Ty);

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
3434
module.getDifferentiabilityWitnessList().push_back(diffWitness);
3535
return diffWitness;
3636
}
37+
38+
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
39+
return std::make_tuple(originalFunction->getName(), parameterIndices,
40+
resultIndices, derivativeGenericSignature);
41+
}

0 commit comments

Comments
 (0)