Skip to content

Commit 15f512b

Browse files
author
marcrasi
authored
[AutoDiff] fix SR-12493 (#30817)
We simply needed to upstream `TypeSubstCloner::visitDifferentiableFunctionExtractInst`. The code has a detailed comment explaining what it does.
1 parent c0960d3 commit 15f512b

File tree

3 files changed

+58
-4
lines changed

3 files changed

+58
-4
lines changed

include/swift/SIL/TypeSubstCloner.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,63 @@ class TypeSubstCloner : public SILClonerWithScopes<ImplClass> {
313313
super::visitDestroyValueInst(Destroy);
314314
}
315315

316+
void visitDifferentiableFunctionExtractInst(
317+
DifferentiableFunctionExtractInst *dfei) {
318+
// If the extractee is the original function, do regular cloning.
319+
if (dfei->getExtractee() ==
320+
NormalDifferentiableFunctionTypeComponent::Original) {
321+
super::visitDifferentiableFunctionExtractInst(dfei);
322+
return;
323+
}
324+
// If the extractee is a derivative function, check whether the *remapped
325+
// derivative function type* (BC) is equal to the *derivative remapped
326+
// function type* (AD).
327+
//
328+
// +----------------+ remap +-------------------------+
329+
// | orig. fn type | -------(A)------> | remapped orig. fn type |
330+
// +----------------+ +-------------------------+
331+
// | |
332+
// (B, SILGen) getAutoDiffDerivativeFunctionType (D, here)
333+
// V V
334+
// +----------------+ remap +-------------------------+
335+
// | deriv. fn type | -------(C)------> | remapped deriv. fn type |
336+
// +----------------+ +-------------------------+
337+
//
338+
// (AD) does not always commute with (BC):
339+
// - (AD) is the result of remapping, then computing the derivative type.
340+
// This is the default cloning behavior, but may break invariants in the
341+
// initial SIL generated by SILGen.
342+
// - (BC) is the result of computing the derivative type (SILGen), then
343+
// remapping. This is the expected type, preserving invariants from
344+
// earlier transforms.
345+
//
346+
// If (AD) is not equal to (BC), use (BC) as the explicit type.
347+
SILType remappedOrigType = getOpType(dfei->getOperand()->getType());
348+
auto remappedOrigFnType = remappedOrigType.castTo<SILFunctionType>();
349+
auto derivativeRemappedFnType =
350+
remappedOrigFnType
351+
->getAutoDiffDerivativeFunctionType(
352+
remappedOrigFnType->getDifferentiabilityParameterIndices(),
353+
/*resultIndex*/ 0, dfei->getDerivativeFunctionKind(),
354+
getBuilder().getModule().Types,
355+
LookUpConformanceInModule(SwiftMod))
356+
->getWithoutDifferentiability();
357+
SILType remappedDerivativeFnType = getOpType(dfei->getType());
358+
// If remapped derivative type and derivative remapped type are equal, do
359+
// regular cloning.
360+
if (SILType::getPrimitiveObjectType(derivativeRemappedFnType) ==
361+
remappedDerivativeFnType) {
362+
super::visitDifferentiableFunctionExtractInst(dfei);
363+
return;
364+
}
365+
// Otherwise, explicitly use the remapped derivative type.
366+
recordClonedInstruction(
367+
dfei,
368+
getBuilder().createDifferentiableFunctionExtract(
369+
getOpLocation(dfei->getLoc()), dfei->getExtractee(),
370+
getOpValue(dfei->getOperand()), remappedDerivativeFnType));
371+
}
372+
316373
/// One abstract function in the debug info can only have one set of variables
317374
/// and types. This function determines whether applying the substitutions in
318375
/// \p SubsMap on the generic signature \p Sig will change the generic type

test/AutoDiff/compiler_crashers/sr12493-differentiable-function-extract-subst-function-type.swift renamed to test/AutoDiff/compiler_crashers_fixed/sr12493-differentiable-function-extract-subst-function-type.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: not %target-build-swift -O %s
1+
// RUN: %target-build-swift -O %s
22

33
// SR-12493: SIL verification error regarding substituted function types and
44
// `differentiable_function_extract` instruction. Occurs only with `-O`.

test/AutoDiff/stdlib/differential_operators.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
// RUN: %target-run %t/differential_operators
55
// REQUIRES: executable_test
66

7-
// FIXME(SR-12493): Disable test for `-O` due to SIL verification error.
8-
// REQUIRES: swift_test_mode_optimize_none
9-
107
import _Differentiation
118

129
import StdlibUnittest

0 commit comments

Comments
 (0)