Skip to content

Commit 4ffc714

Browse files
authored
[AutoDiff] Use GenericSignature instead of GenericSignatureImpl *. (#27831)
Use `GenericSignature` instead of `GenericSignatureImpl *` in `AutoDiffConfig` and the differentiation transform. `AutoDiffConfig` changes by @marcrasi.
1 parent 56e4de8 commit 4ffc714

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,11 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
269269
struct AutoDiffConfig {
270270
IndexSubset *parameterIndices;
271271
IndexSubset *resultIndices;
272-
GenericSignatureImpl* derivativeGenericSignature;
272+
GenericSignature derivativeGenericSignature;
273273

274274
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
275275
IndexSubset *resultIndices,
276-
GenericSignatureImpl *derivativeGenericSignature)
276+
GenericSignature derivativeGenericSignature)
277277
: parameterIndices(parameterIndices), resultIndices(resultIndices),
278278
derivativeGenericSignature(derivativeGenericSignature) {}
279279

@@ -443,7 +443,7 @@ namespace llvm {
443443

444444
using swift::AutoDiffConfig;
445445
using swift::AutoDiffDerivativeFunctionKind;
446-
using swift::GenericSignatureImpl;
446+
using swift::GenericSignature;
447447
using swift::IndexSubset;
448448
using swift::SILAutoDiffIndices;
449449

@@ -453,27 +453,29 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
453453
static AutoDiffConfig getEmptyKey() {
454454
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
455455
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
456-
static_cast<GenericSignatureImpl *>(ptr)};
456+
DenseMapInfo<GenericSignature>::getEmptyKey()};
457457
}
458458

459459
static AutoDiffConfig getTombstoneKey() {
460460
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
461461
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
462-
static_cast<GenericSignatureImpl *>(ptr)};
462+
DenseMapInfo<GenericSignature>::getTombstoneKey()};
463463
}
464464

465465
static unsigned getHashValue(const AutoDiffConfig &Val) {
466466
unsigned combinedHash = hash_combine(
467467
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
468468
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
469-
DenseMapInfo<void *>::getHashValue(Val.derivativeGenericSignature));
469+
DenseMapInfo<GenericSignature>::getHashValue(
470+
Val.derivativeGenericSignature));
470471
return combinedHash;
471472
}
472473

473474
static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
474475
return LHS.parameterIndices == RHS.parameterIndices &&
475476
LHS.resultIndices == RHS.resultIndices &&
476-
LHS.derivativeGenericSignature == RHS.derivativeGenericSignature;
477+
DenseMapInfo<GenericSignature>::isEqual(LHS.derivativeGenericSignature,
478+
RHS.derivativeGenericSignature);
477479
}
478480
};
479481

lib/AST/ASTMangler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
434434
auto originalName = key.first;
435435
auto *parameterIndices = key.second.parameterIndices;
436436
auto *resultIndices = key.second.resultIndices;
437-
auto *derivativeGenericSignature = key.second.derivativeGenericSignature;
437+
auto derivativeGenericSignature = key.second.derivativeGenericSignature;
438438

439439
Buffer << "AD__" << originalName << '_';
440440
Buffer << "P" << parameterIndices->getString();

lib/SILGen/SILGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ void SILGenModule::emitDifferentiabilityWitness(
810810
bool reorderSelf = shouldReorderSelf();
811811

812812
CanGenericSignature derivativeCanGenSig;
813-
if (auto *derivativeGenSig = config.derivativeGenericSignature)
813+
if (auto derivativeGenSig = config.derivativeGenericSignature)
814814
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
815815
// TODO(TF-835): Use simpler derivative generic signature logic below when
816816
// type-checking no longer generates implicit `@differentiable` attributes.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,18 +1823,18 @@ void LinearMapInfo::generateDifferentiationDataStructures(
18231823

18241824
class DifferentiableActivityCollection {
18251825
public:
1826-
SmallDenseMap<GenericSignatureImpl *, DifferentiableActivityInfo> activityInfoMap;
1826+
SmallDenseMap<GenericSignature, DifferentiableActivityInfo> activityInfoMap;
18271827
SILFunction &function;
18281828
DominanceInfo *domInfo;
18291829
PostDominanceInfo *postDomInfo;
18301830

18311831
DifferentiableActivityInfo &getActivityInfo(
18321832
GenericSignature assocGenSig, AutoDiffDerivativeFunctionKind kind) {
1833-
auto activityInfoLookup = activityInfoMap.find(assocGenSig.getPointer());
1833+
auto activityInfoLookup = activityInfoMap.find(assocGenSig);
18341834
if (activityInfoLookup != activityInfoMap.end())
18351835
return activityInfoLookup->getSecond();
18361836
auto insertion = activityInfoMap.insert(
1837-
{assocGenSig.getPointer(), DifferentiableActivityInfo(*this, assocGenSig)});
1837+
{assocGenSig, DifferentiableActivityInfo(*this, assocGenSig)});
18381838
return insertion.first->getSecond();
18391839
}
18401840

0 commit comments

Comments
 (0)