@@ -255,14 +255,29 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
255
255
ArrayRef<SILResultInfo> newResults,
256
256
GenericSignature *genericSignature)
257
257
-> CanSILFunctionType {
258
- return SILFunctionType::get (genericSignature
259
- ? genericSignature
260
- : base->getGenericSignature (),
258
+ if (!genericSignature)
259
+ genericSignature = base->getGenericSignature ();
260
+ // If generic signature is specified, use it to canonical result types.
261
+ // This is important for consistent typing for types like:
262
+ // <T : Differentiable, T == T.CotangentVector> (...) ->
263
+ // (@out T.CotangentVector)
264
+ // Which should be canonicalized to:
265
+ // <T : Differentiable, T == T.CotangentVector> (...) ->
266
+ // (@out T)
267
+ ArrayRef<SILResultInfo> results =
268
+ genericSignature
269
+ ? map<SmallVector<SILResultInfo, 4 >>(
270
+ newResults, [&](SILResultInfo resInfo) {
271
+ return resInfo.getWithType (
272
+ resInfo.getType ()->getCanonicalType (genericSignature));
273
+ })
274
+ : newResults;
275
+ return SILFunctionType::get (genericSignature,
261
276
base->getExtInfo (),
262
277
base->getCoroutineKind (),
263
278
base->getCalleeConvention (),
264
279
base->getParameters (), base->getYields (),
265
- newResults , base->getOptionalErrorResult (), ctx,
280
+ results , base->getOptionalErrorResult (), ctx,
266
281
base->getWitnessMethodConformanceOrNone ());
267
282
};
268
283
@@ -328,17 +343,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
328
343
results.push_back ({closureType, ResultConvention::Owned});
329
344
CanSILFunctionType associatedFunction =
330
345
withNewResults (curryLevels.back (), results,
331
- curryLevels. size () == 1 ? whereClauseGenSig : nullptr );
346
+ whereClauseGenSig);
332
347
333
348
auto curryLevelsWithoutLast =
334
349
ArrayRef<SILFunctionType *>(curryLevels).drop_back (1 );
335
- for (auto pair : enumerate(reversed (curryLevelsWithoutLast))) {
336
- unsigned i = pair.index ();
337
- auto *curryLevel = pair.value ();
350
+ for (auto *curryLevel : reversed (curryLevelsWithoutLast))
338
351
associatedFunction = withNewResults (
339
352
curryLevel, {{associatedFunction, ResultConvention::Owned}},
340
- i == curryLevelsWithoutLast.size () - 1 ? whereClauseGenSig : nullptr );
341
- }
353
+ whereClauseGenSig);
342
354
return associatedFunction;
343
355
}
344
356
0 commit comments