Skip to content

Commit 033f09a

Browse files
authored
make createVJP and createEmptyAdjoint work with generic functions (#21062)
1 parent 29fad40 commit 033f09a

File tree

2 files changed

+171
-75
lines changed

2 files changed

+171
-75
lines changed

lib/SIL/SILFunctionType.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
219219
// VJP: (T...) -> ((R...),
220220
// (R.CotangentVector...) -> (T.CotangentVector...))
221221

222+
// RAII that pushes our generic context to `module.Types` so that the calls
223+
// `module.Types.getTypeLowering()` below can understand our generic
224+
// parameter types.
225+
GenericContextScope genericContextScope(module.Types, getGenericSignature());
226+
222227
auto &ctx = getASTContext();
223228

224229
unsigned numParamsWithoutSelf = hasSelfParam() ? getNumParameters() - 1
@@ -263,24 +268,24 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
263268
};
264269

265270
// Given a type, returns its formal SIL parameter info.
266-
auto getFormalParameterInfo = [&module](CanType type) -> SILParameterInfo {
267-
SILType silTy = SILType::getPrimitiveObjectType(type);
271+
auto getFormalParameterInfo = [&](CanType type) -> SILParameterInfo {
272+
auto &typeLowering = module.Types.getTypeLowering(type);
268273
ParameterConvention conv;
269-
if (SILModuleConventions::isPassedIndirectlyInSIL(silTy, module))
274+
if (typeLowering.isFormallyPassedIndirectly())
270275
conv = ParameterConvention::Indirect_In_Guaranteed;
271-
else if (silTy.isTrivial(module))
276+
else if (typeLowering.isTrivial())
272277
conv = ParameterConvention::Direct_Unowned;
273278
else
274279
conv = ParameterConvention::Direct_Guaranteed;
275280
return {type, conv};
276281
};
277282
// Given a type, returns its formal SIL result info.
278-
auto getFormalResultInfo = [&module](CanType type) -> SILResultInfo {
279-
SILType silTy = SILType::getPrimitiveObjectType(type);
283+
auto getFormalResultInfo = [&](CanType type) -> SILResultInfo {
284+
auto &typeLowering = module.Types.getTypeLowering(type);
280285
ResultConvention conv;
281-
if (SILModuleConventions::isPassedIndirectlyInSIL(silTy, module))
286+
if (typeLowering.isFormallyReturnedIndirectly())
282287
conv = ResultConvention::Indirect;
283-
else if (silTy.isTrivial(module))
288+
else if (typeLowering.isTrivial())
284289
conv = ResultConvention::Unowned;
285290
else
286291
conv = ResultConvention::Owned;
@@ -312,7 +317,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
312317
result.getConvention()
313318
});
314319
auto differentialType = SILFunctionType::get(
315-
getGenericSignature(), ExtInfo(), SILCoroutineKind::None,
320+
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
316321
ParameterConvention::Direct_Guaranteed, tangentParams, {},
317322
tangentResults, None, ctx);
318323
SmallVector<SILResultInfo, 8> jvpResults(getResults().begin(),
@@ -343,7 +348,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
343348
getAssociatedType(param.getType(), cotangentDependentType)));
344349
}
345350
auto pullbackType = SILFunctionType::get(
346-
getGenericSignature(), ExtInfo(), SILCoroutineKind::None,
351+
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
347352
ParameterConvention::Direct_Guaranteed, cotangentParams, {},
348353
cotangentResults, {}, ctx);
349354
SmallVector<SILResultInfo, 8> vjpResults(getResults().begin(),

0 commit comments

Comments
 (0)