Skip to content

Commit 6cc8406

Browse files
committed
Fix conformance lookup within DifferentiableActivityInfo.
Define `DifferentiableActivityInfo::getLookupConformanceFunction` helper. Use `LookUpConformanceInModule` when derivative generic signature is undefined.
1 parent 3692167 commit 6cc8406

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,7 +1383,9 @@ using Activity = OptionSet<ActivityFlags>;
13831383
class DifferentiableActivityInfo {
13841384
private:
13851385
DifferentiableActivityCollection &parent;
1386-
GenericSignature assocGenSig = GenericSignature();
1386+
1387+
/// The derivative generic signature.
1388+
GenericSignature derivativeGenericSignature;
13871389

13881390
/// Input values, i.e. parameters (both direct and indirect).
13891391
SmallVector<SILValue, 4> inputValues;
@@ -1401,6 +1403,17 @@ class DifferentiableActivityInfo {
14011403
/// The original function.
14021404
SILFunction &getFunction();
14031405

1406+
/// The conformance lookup function.
1407+
LookupConformanceFn getLookupConformanceFunction() {
1408+
// Look up in derivative generic signature, if defined.
1409+
if (derivativeGenericSignature)
1410+
return LookUpConformanceInSignature(
1411+
derivativeGenericSignature.getPointer());
1412+
// Otherwise, look up in the module.
1413+
return LookUpConformanceInModule(
1414+
getFunction().getModule().getSwiftModule());
1415+
}
1416+
14041417
/// Perform analysis and populate sets.
14051418
void analyze(DominanceInfo *di, PostDominanceInfo *pdi);
14061419

@@ -1420,7 +1433,8 @@ class DifferentiableActivityInfo {
14201433

14211434
public:
14221435
explicit DifferentiableActivityInfo(
1423-
DifferentiableActivityCollection &parent, GenericSignature assocGenSig);
1436+
DifferentiableActivityCollection &parent,
1437+
GenericSignature derivativeGenericSignature);
14241438

14251439
bool isVaried(SILValue value, unsigned independentVariableIndex) const;
14261440
bool isUseful(SILValue value, unsigned dependentVariableIndex) const;
@@ -1834,14 +1848,18 @@ DifferentiableActivityCollection::DifferentiableActivityCollection(
18341848
: function(f), domInfo(di), postDomInfo(pdi) {}
18351849

18361850
DifferentiableActivityInfo::DifferentiableActivityInfo(
1837-
DifferentiableActivityCollection &parent, GenericSignature assocGenSig)
1838-
: parent(parent), assocGenSig(assocGenSig) {
1851+
DifferentiableActivityCollection &parent, GenericSignature derivGenSig)
1852+
: parent(parent), derivativeGenericSignature(derivGenSig) {
18391853
analyze(parent.domInfo, parent.postDomInfo);
18401854
}
18411855

1856+
SILFunction &DifferentiableActivityInfo::getFunction() {
1857+
return parent.function;
1858+
}
1859+
18421860
void DifferentiableActivityInfo::analyze(DominanceInfo *di,
18431861
PostDominanceInfo *pdi) {
1844-
auto &function = parent.function;
1862+
auto &function = getFunction();
18451863
LLVM_DEBUG(getADDebugStream()
18461864
<< "Running activity analysis on @" << function.getName() << '\n');
18471865
// Inputs are just function's arguments, count `n`.
@@ -1905,11 +1923,11 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
19051923
else if (auto *teai = dyn_cast<TupleElementAddrInst>(&inst)) {
19061924
if (isVaried(teai->getOperand(), i)) {
19071925
auto projType = teai->getType().getASTType();
1908-
if (assocGenSig && projType->hasArchetype())
1909-
projType = assocGenSig->getCanonicalTypeInContext(
1926+
if (derivativeGenericSignature && projType->hasArchetype())
1927+
projType = derivativeGenericSignature->getCanonicalTypeInContext(
19101928
projType->mapTypeOutOfContext());
19111929
if (projType->getAutoDiffAssociatedTangentSpace(
1912-
LookUpConformanceInSignature(assocGenSig.getPointer())))
1930+
getLookupConformanceFunction()))
19131931
setVaried(teai, i);
19141932
}
19151933
}

0 commit comments

Comments
 (0)