@@ -313,6 +313,63 @@ class TypeSubstCloner : public SILClonerWithScopes<ImplClass> {
313
313
super::visitDestroyValueInst (Destroy);
314
314
}
315
315
316
+ void visitDifferentiableFunctionExtractInst (
317
+ DifferentiableFunctionExtractInst *dfei) {
318
+ // If the extractee is the original function, do regular cloning.
319
+ if (dfei->getExtractee () ==
320
+ NormalDifferentiableFunctionTypeComponent::Original) {
321
+ super::visitDifferentiableFunctionExtractInst (dfei);
322
+ return ;
323
+ }
324
+ // If the extractee is a derivative function, check whether the *remapped
325
+ // derivative function type* (BC) is equal to the *derivative remapped
326
+ // function type* (AD).
327
+ //
328
+ // +----------------+ remap +-------------------------+
329
+ // | orig. fn type | -------(A)------> | remapped orig. fn type |
330
+ // +----------------+ +-------------------------+
331
+ // | |
332
+ // (B, SILGen) getAutoDiffDerivativeFunctionType (D, here)
333
+ // V V
334
+ // +----------------+ remap +-------------------------+
335
+ // | deriv. fn type | -------(C)------> | remapped deriv. fn type |
336
+ // +----------------+ +-------------------------+
337
+ //
338
+ // (AD) does not always commute with (BC):
339
+ // - (AD) is the result of remapping, then computing the derivative type.
340
+ // This is the default cloning behavior, but may break invariants in the
341
+ // initial SIL generated by SILGen.
342
+ // - (BC) is the result of computing the derivative type (SILGen), then
343
+ // remapping. This is the expected type, preserving invariants from
344
+ // earlier transforms.
345
+ //
346
+ // If (AD) is not equal to (BC), use (BC) as the explicit type.
347
+ SILType remappedOrigType = getOpType (dfei->getOperand ()->getType ());
348
+ auto remappedOrigFnType = remappedOrigType.castTo <SILFunctionType>();
349
+ auto derivativeRemappedFnType =
350
+ remappedOrigFnType
351
+ ->getAutoDiffDerivativeFunctionType (
352
+ remappedOrigFnType->getDifferentiabilityParameterIndices (),
353
+ /* resultIndex*/ 0 , dfei->getDerivativeFunctionKind (),
354
+ getBuilder ().getModule ().Types ,
355
+ LookUpConformanceInModule (SwiftMod))
356
+ ->getWithoutDifferentiability ();
357
+ SILType remappedDerivativeFnType = getOpType (dfei->getType ());
358
+ // If remapped derivative type and derivative remapped type are equal, do
359
+ // regular cloning.
360
+ if (SILType::getPrimitiveObjectType (derivativeRemappedFnType) ==
361
+ remappedDerivativeFnType) {
362
+ super::visitDifferentiableFunctionExtractInst (dfei);
363
+ return ;
364
+ }
365
+ // Otherwise, explicitly use the remapped derivative type.
366
+ recordClonedInstruction (
367
+ dfei,
368
+ getBuilder ().createDifferentiableFunctionExtract (
369
+ getOpLocation (dfei->getLoc ()), dfei->getExtractee (),
370
+ getOpValue (dfei->getOperand ()), remappedDerivativeFnType));
371
+ }
372
+
316
373
// / One abstract function in the debug info can only have one set of variables
317
374
// / and types. This function determines whether applying the substitutions in
318
375
// / \p SubsMap on the generic signature \p Sig will change the generic type
0 commit comments