Skip to content

Commit 9ae76a4

Browse files
committed
Fix remaining test failures.
Remaining test failures were related to `differentiable_function_extract` cloning during generic specialization. Add explanatory comment to `TypeSubstCloner::visitDifferentiableFunctionExtractInst`: +----------------+ remap +-------------------------+ | orig. fn type | -------(A)------> | remapped orig. fn type | +----------------+ +-------------------------+ | | (B, SILGen) getAutoDiffDerivativeFunctionType (D, here) V V +----------------+ remap +-------------------------+ | deriv. fn type | -------(C)------> | remapped deriv. fn type | +----------------+ +-------------------------+ (AD) does not always commute with (BC): - (AD) is the result of remapping, then computing the derivative type. This is the default cloning behavior, but may break invariants in the initial SIL generated by SILGen. - (BC) is the result of computing the derivative type (SILGen), then remapping. This is the expected type, preserving invariants from earlier transforms. If (AD) is not equal to (BC), (BC) must be used as the explicit type. Done with @marcrasi.
1 parent 2035bce commit 9ae76a4

File tree

5 files changed

+79
-9
lines changed

5 files changed

+79
-9
lines changed

include/swift/SIL/SILCloner.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2859,10 +2859,13 @@ template<typename ImplClass>
28592859
void SILCloner<ImplClass>::
28602860
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst) {
28612861
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
2862+
Optional<SILType> explicitExtracteeType = None;
2863+
if (Inst->hasExplicitExtracteeType())
2864+
explicitExtracteeType = Inst->getType();
28622865
recordClonedInstruction(
28632866
Inst, getBuilder().createDifferentiableFunctionExtract(
28642867
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
2865-
getOpValue(Inst->getFunctionOperand())));
2868+
getOpValue(Inst->getFunctionOperand()), explicitExtracteeType));
28662869
}
28672870

28682871
template<typename ImplClass>

include/swift/SIL/TypeSubstCloner.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,65 @@ class TypeSubstCloner : public SILClonerWithScopes<ImplClass> {
315315
super::visitDestroyValueInst(Destroy);
316316
}
317317

318+
// SWIFT_ENABLE_TENSORFLOW
319+
void visitDifferentiableFunctionExtractInst(
320+
DifferentiableFunctionExtractInst *dfei) {
321+
// If the extractee is the original function, do regular cloning.
322+
if (dfei->getExtractee() ==
323+
NormalDifferentiableFunctionTypeComponent::Original) {
324+
super::visitDifferentiableFunctionExtractInst(dfei);
325+
return;
326+
}
327+
// If the extractee is a derivative function, check whether the *remapped
328+
// derivative function type* (BC) is equal to the *derivative remapped
329+
// function type* (AD).
330+
//
331+
// +----------------+ remap +-------------------------+
332+
// | orig. fn type | -------(A)------> | remapped orig. fn type |
333+
// +----------------+ +-------------------------+
334+
// | |
335+
// (B, SILGen) getAutoDiffDerivativeFunctionType (D, here)
336+
// V V
337+
// +----------------+ remap +-------------------------+
338+
// | deriv. fn type | -------(C)------> | remapped deriv. fn type |
339+
// +----------------+ +-------------------------+
340+
//
341+
// (AD) does not always commute with (BC):
342+
// - (AD) is the result of remapping, then computing the derivative type.
343+
// This is the default cloning behavior, but may break invariants in the
344+
// initial SIL generated by SILGen.
345+
// - (BC) is the result of computing the derivative type (SILGen), then
346+
// remapping. This is the expected type, preserving invariants from
347+
// earlier transforms.
348+
//
349+
// If (AD) is not equal to (BC), use (BC) as the explicit type.
350+
SILType remappedOrigType = getOpType(dfei->getFunctionOperand()->getType());
351+
auto remappedOrigFnType = remappedOrigType.castTo<SILFunctionType>();
352+
auto derivativeRemappedFnType =
353+
remappedOrigFnType
354+
->getAutoDiffDerivativeFunctionType(
355+
remappedOrigFnType->getDifferentiabilityParameterIndices(),
356+
/*resultIndex*/ 0, dfei->getDerivativeFunctionKind(),
357+
getBuilder().getModule().Types,
358+
LookUpConformanceInModule(SwiftMod))
359+
->getWithoutDifferentiability();
360+
SILType remappedDerivativeFnType = getOpType(dfei->getType());
361+
// If remapped derivative type and derivative remapped type are equal, do
362+
// regular cloning.
363+
if (SILType::getPrimitiveObjectType(derivativeRemappedFnType) ==
364+
remappedDerivativeFnType) {
365+
super::visitDifferentiableFunctionExtractInst(dfei);
366+
return;
367+
}
368+
// Otherwise, explicitly use the remapped derivative type.
369+
recordClonedInstruction(
370+
dfei,
371+
getBuilder().createDifferentiableFunctionExtract(
372+
getOpLocation(dfei->getLoc()), dfei->getExtractee(),
373+
getOpValue(dfei->getFunctionOperand()), remappedDerivativeFnType));
374+
}
375+
// SWIFT_ENABLE_TENSORFLOW END
376+
318377
/// One abstract function in the debug info can only have one set of variables
319378
/// and types. This function determines whether applying the substitutions in
320379
/// \p SubsMap on the generic signature \p Sig will change the generic type

lib/SIL/SILInstructions.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,6 @@ getExtracteeType(
709709
auto resultFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
710710
fnTy->getDifferentiabilityParameterIndices(), /*resultIndex*/ 0,
711711
*kindOpt, module.Types, LookUpConformanceInModule(module.getSwiftModule()));
712-
llvm::dbgs() << "getExtracteeType for " << function << ":\n" << SILType::getPrimitiveObjectType(resultFnTy) << "\n";
713712
return SILType::getPrimitiveObjectType(resultFnTy);
714713
}
715714

@@ -725,8 +724,14 @@ DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst(
725724
HasExplicitExtracteeType(extracteeType.hasValue()) {
726725
#ifndef NDEBUG
727726
if (extracteeType.hasValue()) {
728-
assert(module.getStage() == SILStage::Lowered &&
729-
"Explicit type is valid only in lowered SIL");
727+
// Note: explicit extractee type is used to avoid inconsistent typing in:
728+
// - Canonical SIL, due to generic specialization.
729+
// - Lowered SIL, due to LoadableByAddress.
730+
// See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
731+
// explanation of how explicit extractee type is used.
732+
assert((module.getStage() == SILStage::Canonical ||
733+
module.getStage() == SILStage::Lowered) &&
734+
"Explicit type is valid only in canonical or lowered SIL");
730735
}
731736
#endif
732737
}

lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,8 +1431,12 @@ void JVPEmitter::visitReturnInst(ReturnInst *ri) {
14311431
loc, differentialRef, jvpSubstMap, {diffStructVal},
14321432
ParameterConvention::Direct_Guaranteed);
14331433

1434-
auto differentialType = jvp->getLoweredFunctionType()->getResults().back().getSILStorageInterfaceType();
1435-
differentialType = differentialType.substGenericArgs(getModule(), jvpSubstMap, TypeExpansionContext::minimal());
1434+
auto differentialType = jvp->getLoweredFunctionType()
1435+
->getResults()
1436+
.back()
1437+
.getSILStorageInterfaceType();
1438+
differentialType = differentialType.substGenericArgs(
1439+
getModule(), jvpSubstMap, TypeExpansionContext::minimal());
14361440
differentialType = differentialType.subst(getModule(), jvpSubstMap);
14371441
auto differentialFnType = differentialType.castTo<SILFunctionType>();
14381442

@@ -1452,8 +1456,7 @@ void JVPEmitter::visitReturnInst(ReturnInst *ri) {
14521456
// applies, the return type may be ABI-incomaptible with the type of the
14531457
// partially applied differential. In these cases, produce an undef and rely
14541458
// on other code to emit a diagnostic.
1455-
differentialValue = SILUndef::get(differentialType,
1456-
*differentialPartialApply->getFunction());
1459+
differentialValue = SILUndef::get(differentialType, *jvp);
14571460
}
14581461

14591462
// Return a tuple of the original result and differential.

lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ void VJPEmitter::visitReturnInst(ReturnInst *ri) {
363363
// applies, the return type may be ABI-incomaptible with the type of the
364364
// partially applied pullback. In these cases, produce an undef and rely on
365365
// other code to emit a diagnostic.
366-
pullbackValue = SILUndef::get(pullbackType, *pullbackPartialApply->getFunction());
366+
pullbackValue = SILUndef::get(pullbackType, *vjp);
367367
}
368368

369369
// Return a tuple of the original result and pullback.

0 commit comments

Comments
 (0)