@@ -1383,7 +1383,9 @@ using Activity = OptionSet<ActivityFlags>;
1383
1383
class DifferentiableActivityInfo {
1384
1384
private:
1385
1385
DifferentiableActivityCollection &parent;
1386
- GenericSignature assocGenSig = GenericSignature();
1386
+
1387
+ // / The derivative generic signature.
1388
+ GenericSignature derivativeGenericSignature;
1387
1389
1388
1390
// / Input values, i.e. parameters (both direct and indirect).
1389
1391
SmallVector<SILValue, 4 > inputValues;
@@ -1401,6 +1403,17 @@ class DifferentiableActivityInfo {
1401
1403
// / The original function.
1402
1404
SILFunction &getFunction ();
1403
1405
1406
+ // / The conformance lookup function.
1407
+ LookupConformanceFn getLookupConformanceFunction () {
1408
+ // Look up in derivative generic signature, if defined.
1409
+ if (derivativeGenericSignature)
1410
+ return LookUpConformanceInSignature (
1411
+ derivativeGenericSignature.getPointer ());
1412
+ // Otherwise, look up in the module.
1413
+ return LookUpConformanceInModule (
1414
+ getFunction ().getModule ().getSwiftModule ());
1415
+ }
1416
+
1404
1417
// / Perform analysis and populate sets.
1405
1418
void analyze (DominanceInfo *di, PostDominanceInfo *pdi);
1406
1419
@@ -1420,7 +1433,8 @@ class DifferentiableActivityInfo {
1420
1433
1421
1434
public:
1422
1435
explicit DifferentiableActivityInfo (
1423
- DifferentiableActivityCollection &parent, GenericSignature assocGenSig);
1436
+ DifferentiableActivityCollection &parent,
1437
+ GenericSignature derivativeGenericSignature);
1424
1438
1425
1439
bool isVaried (SILValue value, unsigned independentVariableIndex) const ;
1426
1440
bool isUseful (SILValue value, unsigned dependentVariableIndex) const ;
@@ -1834,14 +1848,18 @@ DifferentiableActivityCollection::DifferentiableActivityCollection(
1834
1848
: function(f), domInfo(di), postDomInfo(pdi) {}
1835
1849
1836
1850
DifferentiableActivityInfo::DifferentiableActivityInfo (
1837
- DifferentiableActivityCollection &parent, GenericSignature assocGenSig )
1838
- : parent(parent), assocGenSig(assocGenSig ) {
1851
+ DifferentiableActivityCollection &parent, GenericSignature derivGenSig )
1852
+ : parent(parent), derivativeGenericSignature(derivGenSig ) {
1839
1853
analyze (parent.domInfo , parent.postDomInfo );
1840
1854
}
1841
1855
1856
+ SILFunction &DifferentiableActivityInfo::getFunction () {
1857
+ return parent.function ;
1858
+ }
1859
+
1842
1860
void DifferentiableActivityInfo::analyze (DominanceInfo *di,
1843
1861
PostDominanceInfo *pdi) {
1844
- auto &function = parent. function ;
1862
+ auto &function = getFunction () ;
1845
1863
LLVM_DEBUG (getADDebugStream ()
1846
1864
<< " Running activity analysis on @" << function.getName () << ' \n ' );
1847
1865
// Inputs are just function's arguments, count `n`.
@@ -1905,11 +1923,11 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1905
1923
else if (auto *teai = dyn_cast<TupleElementAddrInst>(&inst)) {
1906
1924
if (isVaried (teai->getOperand (), i)) {
1907
1925
auto projType = teai->getType ().getASTType ();
1908
- if (assocGenSig && projType->hasArchetype ())
1909
- projType = assocGenSig ->getCanonicalTypeInContext (
1926
+ if (derivativeGenericSignature && projType->hasArchetype ())
1927
+ projType = derivativeGenericSignature ->getCanonicalTypeInContext (
1910
1928
projType->mapTypeOutOfContext ());
1911
1929
if (projType->getAutoDiffAssociatedTangentSpace (
1912
- LookUpConformanceInSignature (assocGenSig. getPointer () )))
1930
+ getLookupConformanceFunction ( )))
1913
1931
setVaried (teai, i);
1914
1932
}
1915
1933
}
0 commit comments