@@ -219,6 +219,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
219
219
// VJP: (T...) -> ((R...),
220
220
// (R.CotangentVector...) -> (T.CotangentVector...))
221
221
222
+ // RAII that pushes our generic context to `module.Types` so that the calls
223
+ // `module.Types.getTypeLowering()` below can understand our generic
224
+ // parameter types.
225
+ GenericContextScope genericContextScope (module .Types , getGenericSignature ());
226
+
222
227
auto &ctx = getASTContext ();
223
228
224
229
unsigned numParamsWithoutSelf = hasSelfParam () ? getNumParameters () - 1
@@ -263,24 +268,24 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
263
268
};
264
269
265
270
// Given a type, returns its formal SIL parameter info.
266
- auto getFormalParameterInfo = [&module ](CanType type) -> SILParameterInfo {
267
- SILType silTy = SILType::getPrimitiveObjectType (type);
271
+ auto getFormalParameterInfo = [&](CanType type) -> SILParameterInfo {
272
+ auto &typeLowering = module . Types . getTypeLowering (type);
268
273
ParameterConvention conv;
269
- if (SILModuleConventions::isPassedIndirectlyInSIL (silTy, module ))
274
+ if (typeLowering. isFormallyPassedIndirectly ( ))
270
275
conv = ParameterConvention::Indirect_In_Guaranteed;
271
- else if (silTy .isTrivial (module ))
276
+ else if (typeLowering .isTrivial ())
272
277
conv = ParameterConvention::Direct_Unowned;
273
278
else
274
279
conv = ParameterConvention::Direct_Guaranteed;
275
280
return {type, conv};
276
281
};
277
282
// Given a type, returns its formal SIL result info.
278
- auto getFormalResultInfo = [&module ](CanType type) -> SILResultInfo {
279
- SILType silTy = SILType::getPrimitiveObjectType (type);
283
+ auto getFormalResultInfo = [&](CanType type) -> SILResultInfo {
284
+ auto &typeLowering = module . Types . getTypeLowering (type);
280
285
ResultConvention conv;
281
- if (SILModuleConventions::isPassedIndirectlyInSIL (silTy, module ))
286
+ if (typeLowering. isFormallyReturnedIndirectly ( ))
282
287
conv = ResultConvention::Indirect;
283
- else if (silTy .isTrivial (module ))
288
+ else if (typeLowering .isTrivial ())
284
289
conv = ResultConvention::Unowned;
285
290
else
286
291
conv = ResultConvention::Owned;
@@ -312,7 +317,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
312
317
result.getConvention ()
313
318
});
314
319
auto differentialType = SILFunctionType::get (
315
- getGenericSignature () , ExtInfo (), SILCoroutineKind::None,
320
+ /* genericSignature */ nullptr , ExtInfo (), SILCoroutineKind::None,
316
321
ParameterConvention::Direct_Guaranteed, tangentParams, {},
317
322
tangentResults, None, ctx);
318
323
SmallVector<SILResultInfo, 8 > jvpResults (getResults ().begin (),
@@ -343,7 +348,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
343
348
getAssociatedType (param.getType (), cotangentDependentType)));
344
349
}
345
350
auto pullbackType = SILFunctionType::get (
346
- getGenericSignature () , ExtInfo (), SILCoroutineKind::None,
351
+ /* genericSignature */ nullptr , ExtInfo (), SILCoroutineKind::None,
347
352
ParameterConvention::Direct_Guaranteed, cotangentParams, {},
348
353
cotangentResults, {}, ctx);
349
354
SmallVector<SILResultInfo, 8 > vjpResults (getResults ().begin (),
0 commit comments