Skip to content

Commit 808c7ec

Browse files
committed
[AutoDiff] Fix two derivative type calculation bugs caught by RequirementMachine.
1. When calculating the differential type of an original function with an inout parameter and when the inout parameter has a type parameter, the inout parameter should get a generic parameter in the subst generic signature of the differential but it currently doesn't. This causes SILGen to attempt to reabstract the differential value in the JVP protocol witness thunk, whilst the generic signature is lacking requirements, leading to a requirement machine error. This patch fixes the calculation so that the JVP's result type (the differential type) always matches the witness thunk's result type. Wrong type: ```swift sil private [transparent] [thunk] [ossa] @... <τ_0_0 where τ_0_0 : Differentiable> (...) -> @owned @callee_guaranteed @Substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector> { %6 = differentiable_function_extract [jvp] %5 : $@differentiable(reverse) @convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @noDerivative @inout τ_0_0, @noDerivative SR_13305_Struct) -> () // user: %7 HERE ====> %7 = apply %6<τ_0_0>(%0, %1, %3) : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @inout τ_0_0, SR_13305_Struct) -> @owned @callee_guaranteed @Substituted <τ_0_0> (@in_guaranteed τ_0_0) -> @out τ_0_0 for <τ_0_0.TangentVector> ``` Should be: ```swift %7 = apply %6<τ_0_0>(%0, %1, %3) : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @inout τ_0_0, SR_13305_Struct) -> @owned @callee_guaranteed @Substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector> ``` 2. `TypeConverter::makeConstantInterfaceType` is not passing down the derivative generic signature to `SILFunctionType::getAutoDiffDerivativeFunctionType` for class methods, and this was caught by RequirementMachine during vtable emission. This patch fixes that. Partially resolves rdar://82549134. The only remaining tests that require `-requirement-machine=off` are SILOptimizer/semantic_member_accessors_sil.swift and SILOptimizer/differentiation_diagnostics.swift which I will fix next. Then I'll do a proper fix for workaround #39416.
1 parent e77b14a commit 808c7ec

File tree

10 files changed

+90
-93
lines changed

10 files changed

+90
-93
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 75 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
396396
GenericSignature()).getCanonicalSignature();
397397
}
398398

399+
/// Given an original type, computes its tangent type for the purpose of
400+
/// building a linear map using this type. When the original type is an
401+
/// archetype or contains a type parameter, appends a new generic parameter and
402+
/// a corresponding replacement type to the given containers.
403+
static CanType getAutoDiffTangentTypeForLinearMap(
404+
Type originalType,
405+
LookupConformanceFn lookupConformance,
406+
SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
407+
SmallVectorImpl<Type> &substReplacements,
408+
ASTContext &context
409+
) {
410+
auto maybeTanType = originalType->getAutoDiffTangentSpace(lookupConformance);
411+
assert(maybeTanType && "Type does not have a tangent space?");
412+
auto tanType = maybeTanType->getCanonicalType();
413+
// If concrete, the tangent type is concrete.
414+
if (!tanType->hasArchetype() && !tanType->hasTypeParameter())
415+
return tanType;
416+
// Otherwise, the tangent type is a new generic parameter substituted for the
417+
// tangent type.
418+
auto gpIndex = substGenericParams.size();
419+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, context);
420+
substGenericParams.push_back(gpType);
421+
substReplacements.push_back(tanType);
422+
return gpType;
423+
}
424+
399425
/// Returns the differential type for the given original function type,
400426
/// parameter indices, and result index.
401427
static CanSILFunctionType getAutoDiffDifferentialType(
@@ -471,45 +497,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
471497
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
472498
SmallVector<SILParameterInfo, 8> differentialParams;
473499
for (auto &param : diffParams) {
474-
auto paramTan =
475-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
476-
assert(paramTan && "Parameter type does not have a tangent space?");
477-
auto paramTanType = paramTan->getCanonicalType();
478-
auto paramConv = getTangentParameterConvention(paramTanType,
479-
param.getConvention());
480-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
481-
differentialParams.push_back(
482-
{paramTan->getCanonicalType(), paramConv});
483-
} else {
484-
auto gpIndex = substGenericParams.size();
485-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
486-
substGenericParams.push_back(gpType);
487-
substReplacements.push_back(paramTanType);
488-
differentialParams.push_back({gpType, paramConv});
489-
}
500+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
501+
param.getInterfaceType(), lookupConformance,
502+
substGenericParams, substReplacements, ctx);
503+
auto paramConv = getTangentParameterConvention(
504+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
505+
param.getInterfaceType()
506+
->getAutoDiffTangentSpace(lookupConformance)
507+
->getCanonicalType(),
508+
param.getConvention());
509+
differentialParams.push_back({paramTanType, paramConv});
490510
}
491511
SmallVector<SILResultInfo, 1> differentialResults;
492512
for (auto resultIndex : resultIndices->getIndices()) {
493513
// Handle formal original result.
494514
if (resultIndex < originalFnTy->getNumResults()) {
495515
auto &result = originalResults[resultIndex];
496-
auto resultTan =
497-
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
498-
assert(resultTan && "Result type does not have a tangent space?");
499-
auto resultTanType = resultTan->getCanonicalType();
500-
auto resultConv =
501-
getTangentResultConvention(resultTanType, result.getConvention());
502-
if (!resultTanType->hasArchetype() &&
503-
!resultTanType->hasTypeParameter()) {
504-
differentialResults.push_back(
505-
{resultTan->getCanonicalType(), resultConv});
506-
} else {
507-
auto gpIndex = substGenericParams.size();
508-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
509-
substGenericParams.push_back(gpType);
510-
substReplacements.push_back(resultTanType);
511-
differentialResults.push_back({gpType, resultConv});
512-
}
516+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
517+
result.getInterfaceType(), lookupConformance,
518+
substGenericParams, substReplacements, ctx);
519+
auto resultConv = getTangentResultConvention(
520+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
521+
result.getInterfaceType()
522+
->getAutoDiffTangentSpace(lookupConformance)
523+
->getCanonicalType(),
524+
result.getConvention());
525+
differentialResults.push_back({resultTanType, resultConv});
513526
continue;
514527
}
515528
// Handle original `inout` parameter.
@@ -524,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
524537
if (parameterIndices->contains(paramIndex))
525538
continue;
526539
auto inoutParam = originalFnTy->getParameters()[paramIndex];
527-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
528-
lookupConformance);
529-
assert(paramTan && "Parameter type does not have a tangent space?");
540+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
541+
inoutParam.getInterfaceType(), lookupConformance,
542+
substGenericParams, substReplacements, ctx);
530543
differentialResults.push_back(
531-
{paramTan->getCanonicalType(), ResultConvention::Indirect});
544+
{inoutParamTanType, ResultConvention::Indirect});
532545
}
533546

534547
SubstitutionMap substitutions;
@@ -635,23 +648,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
635648
// Handle formal original result.
636649
if (resultIndex < originalFnTy->getNumResults()) {
637650
auto &origRes = originalResults[resultIndex];
638-
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
639-
lookupConformance);
640-
assert(resultTan && "Result type does not have a tangent space?");
641-
auto resultTanType = resultTan->getCanonicalType();
642-
auto paramTanConvention = getTangentParameterConventionForOriginalResult(
643-
resultTanType, origRes.getConvention());
644-
if (!resultTanType->hasArchetype() &&
645-
!resultTanType->hasTypeParameter()) {
646-
auto resultTanType = resultTan->getCanonicalType();
647-
pullbackParams.push_back({resultTanType, paramTanConvention});
648-
} else {
649-
auto gpIndex = substGenericParams.size();
650-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
651-
substGenericParams.push_back(gpType);
652-
substReplacements.push_back(resultTanType);
653-
pullbackParams.push_back({gpType, paramTanConvention});
654-
}
651+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
652+
origRes.getInterfaceType(), lookupConformance,
653+
substGenericParams, substReplacements, ctx);
654+
auto paramConv = getTangentParameterConventionForOriginalResult(
655+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
656+
origRes.getInterfaceType()
657+
->getAutoDiffTangentSpace(lookupConformance)
658+
->getCanonicalType(),
659+
origRes.getConvention());
660+
pullbackParams.push_back({resultTanType, paramConv});
655661
continue;
656662
}
657663
// Handle original `inout` parameter.
@@ -661,28 +667,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
661667
auto paramIndex =
662668
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
663669
auto inoutParam = originalFnTy->getParameters()[paramIndex];
664-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
665-
lookupConformance);
666-
assert(paramTan && "Parameter type does not have a tangent space?");
667670
// The pullback parameter convention depends on whether the original `inout`
668671
// paramater is a differentiability parameter.
669672
// - If yes, the pullback parameter convention is `@inout`.
670673
// - If no, the pullback parameter convention is `@in_guaranteed`.
674+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
675+
inoutParam.getInterfaceType(), lookupConformance,
676+
substGenericParams, substReplacements, ctx);
671677
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
672678
auto paramTanConvention = isWrtInoutParameter
673-
? inoutParam.getConvention()
674-
: ParameterConvention::Indirect_In_Guaranteed;
675-
auto paramTanType = paramTan->getCanonicalType();
676-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
677-
pullbackParams.push_back(
678-
SILParameterInfo(paramTanType, paramTanConvention));
679-
} else {
680-
auto gpIndex = substGenericParams.size();
681-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
682-
substGenericParams.push_back(gpType);
683-
substReplacements.push_back(paramTanType);
684-
pullbackParams.push_back({gpType, paramTanConvention});
685-
}
679+
? inoutParam.getConvention()
680+
: ParameterConvention::Indirect_In_Guaranteed;
681+
pullbackParams.push_back({inoutParamTanType, paramTanConvention});
686682
}
687683

688684
// Collect pullback results.
@@ -694,21 +690,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
694690
// and always appear as pullback parameters.
695691
if (param.isIndirectInOut())
696692
continue;
697-
auto paramTan =
698-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
699-
assert(paramTan && "Parameter type does not have a tangent space?");
700-
auto paramTanType = paramTan->getCanonicalType();
693+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
694+
param.getInterfaceType(), lookupConformance,
695+
substGenericParams, substReplacements, ctx);
701696
auto resultTanConvention = getTangentResultConventionForOriginalParameter(
702-
paramTanType, param.getConvention());
703-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
704-
pullbackResults.push_back({paramTanType, resultTanConvention});
705-
} else {
706-
auto gpIndex = substGenericParams.size();
707-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
708-
substGenericParams.push_back(gpType);
709-
substReplacements.push_back(paramTanType);
710-
pullbackResults.push_back({gpType, resultTanConvention});
711-
}
697+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
698+
param.getInterfaceType()
699+
->getAutoDiffTangentSpace(lookupConformance)
700+
->getCanonicalType(),
701+
param.getConvention());
702+
pullbackResults.push_back({paramTanType, resultTanConvention});
712703
}
713704
SubstitutionMap substitutions;
714705
if (!substGenericParams.empty()) {

lib/SIL/IR/TypeLowering.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2599,9 +2599,15 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
25992599
if (auto *derivativeId = c.getDerivativeFunctionIdentifier()) {
26002600
auto originalFnTy =
26012601
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
2602+
// Protocol witness derivatives cannot have a derivative generic signature,
2603+
// but class method derivatives can.
2604+
GenericSignature derivativeGenSig = nullptr;
2605+
if (isa<ClassDecl>(c.getAbstractFunctionDecl()->getInnermostTypeContext()))
2606+
derivativeGenSig = derivativeId->getDerivativeGenericSignature();
26022607
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
26032608
derivativeId->getParameterIndices(), derivativeId->getKind(),
2604-
LookUpConformanceInModule(&M));
2609+
LookUpConformanceInModule(&M),
2610+
derivativeGenSig);
26052611
return cast<AnyFunctionType>(derivativeFnTy->getCanonicalType());
26062612
}
26072613

test/AutoDiff/SIL/Parse/sildeclref.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine=off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-sil-opt %s -module-name=sildeclref_parse | %target-sil-opt -module-name=sildeclref_parse | %FileCheck %s
22
// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions.
33

44
import Swift

test/AutoDiff/SILGen/vtable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-silgen %s -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
22

33
// Test derivative function vtable entries for `@differentiable` class members:
44
// - Methods.

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
1+
// RUN: %target-swift-frontend -emit-sil -requirement-machine=off -verify %s
22

33
// Test differentiation transform diagnostics.
44

test/AutoDiff/compiler_crashers_fixed/sr12744-unhandled-pullback-indirect-result.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
22

33
// SR-12744: Pullback generation crash for unhandled indirect result.
44
// May be due to inconsistent derivative function type calculation logic in

test/AutoDiff/compiler_crashers_fixed/sr14240-symbol-in-ir-file-not-tbd-file.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift
22

33
// REQUIRES: executable_test
44

test/AutoDiff/mangling.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify -requirement-machine=off %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify %s | %FileCheck %s
22

33
import _Differentiation
44

test/AutoDiff/validation-test/forward_mode_simple.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation -Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
22
// REQUIRES: executable_test
33

44
import StdlibUnittest

test/AutoDiff/validation-test/inout_parameters.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

44
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.

0 commit comments

Comments
 (0)