Skip to content

Commit 2629654

Browse files
authored
Revert "[AutoDiff] Fix two derivative type calculation bugs caught by RequirementMachine"
1 parent 246896f commit 2629654

File tree

10 files changed

+93
-90
lines changed

10 files changed

+93
-90
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 84 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -408,32 +408,6 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
408408
return buildGenericSignature(ctx, sig, {}, reqs).getCanonicalSignature();
409409
}
410410

411-
/// Given an original type, computes its tangent type for the purpose of
412-
/// building a linear map using this type. When the original type is an
413-
/// archetype or contains a type parameter, appends a new generic parameter and
414-
/// a corresponding replacement type to the given containers.
415-
static CanType getAutoDiffTangentTypeForLinearMap(
416-
Type originalType,
417-
LookupConformanceFn lookupConformance,
418-
SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
419-
SmallVectorImpl<Type> &substReplacements,
420-
ASTContext &context
421-
) {
422-
auto maybeTanType = originalType->getAutoDiffTangentSpace(lookupConformance);
423-
assert(maybeTanType && "Type does not have a tangent space?");
424-
auto tanType = maybeTanType->getCanonicalType();
425-
// If concrete, the tangent type is concrete.
426-
if (!tanType->hasArchetype() && !tanType->hasTypeParameter())
427-
return tanType;
428-
// Otherwise, the tangent type is a new generic parameter substituted for the
429-
// tangent type.
430-
auto gpIndex = substGenericParams.size();
431-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, context);
432-
substGenericParams.push_back(gpType);
433-
substReplacements.push_back(tanType);
434-
return gpType;
435-
}
436-
437411
/// Returns the differential type for the given original function type,
438412
/// parameter indices, and result index.
439413
static CanSILFunctionType getAutoDiffDifferentialType(
@@ -510,32 +484,45 @@ static CanSILFunctionType getAutoDiffDifferentialType(
510484
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
511485
SmallVector<SILParameterInfo, 8> differentialParams;
512486
for (auto &param : diffParams) {
513-
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
514-
param.getInterfaceType(), lookupConformance,
515-
substGenericParams, substReplacements, ctx);
516-
auto paramConv = getTangentParameterConvention(
517-
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
518-
param.getInterfaceType()
519-
->getAutoDiffTangentSpace(lookupConformance)
520-
->getCanonicalType(),
521-
param.getConvention());
522-
differentialParams.push_back({paramTanType, paramConv});
487+
auto paramTan =
488+
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
489+
assert(paramTan && "Parameter type does not have a tangent space?");
490+
auto paramTanType = paramTan->getCanonicalType();
491+
auto paramConv = getTangentParameterConvention(paramTanType,
492+
param.getConvention());
493+
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
494+
differentialParams.push_back(
495+
{paramTan->getCanonicalType(), paramConv});
496+
} else {
497+
auto gpIndex = substGenericParams.size();
498+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
499+
substGenericParams.push_back(gpType);
500+
substReplacements.push_back(paramTanType);
501+
differentialParams.push_back({gpType, paramConv});
502+
}
523503
}
524504
SmallVector<SILResultInfo, 1> differentialResults;
525505
for (auto resultIndex : resultIndices->getIndices()) {
526506
// Handle formal original result.
527507
if (resultIndex < originalFnTy->getNumResults()) {
528508
auto &result = originalResults[resultIndex];
529-
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
530-
result.getInterfaceType(), lookupConformance,
531-
substGenericParams, substReplacements, ctx);
532-
auto resultConv = getTangentResultConvention(
533-
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
534-
result.getInterfaceType()
535-
->getAutoDiffTangentSpace(lookupConformance)
536-
->getCanonicalType(),
537-
result.getConvention());
538-
differentialResults.push_back({resultTanType, resultConv});
509+
auto resultTan =
510+
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
511+
assert(resultTan && "Result type does not have a tangent space?");
512+
auto resultTanType = resultTan->getCanonicalType();
513+
auto resultConv =
514+
getTangentResultConvention(resultTanType, result.getConvention());
515+
if (!resultTanType->hasArchetype() &&
516+
!resultTanType->hasTypeParameter()) {
517+
differentialResults.push_back(
518+
{resultTan->getCanonicalType(), resultConv});
519+
} else {
520+
auto gpIndex = substGenericParams.size();
521+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
522+
substGenericParams.push_back(gpType);
523+
substReplacements.push_back(resultTanType);
524+
differentialResults.push_back({gpType, resultConv});
525+
}
539526
continue;
540527
}
541528
// Handle original `inout` parameter.
@@ -550,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
550537
if (parameterIndices->contains(paramIndex))
551538
continue;
552539
auto inoutParam = originalFnTy->getParameters()[paramIndex];
553-
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
554-
inoutParam.getInterfaceType(), lookupConformance,
555-
substGenericParams, substReplacements, ctx);
540+
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
541+
lookupConformance);
542+
assert(paramTan && "Parameter type does not have a tangent space?");
556543
differentialResults.push_back(
557-
{inoutParamTanType, ResultConvention::Indirect});
544+
{paramTan->getCanonicalType(), ResultConvention::Indirect});
558545
}
559546

560547
SubstitutionMap substitutions;
@@ -661,16 +648,23 @@ static CanSILFunctionType getAutoDiffPullbackType(
661648
// Handle formal original result.
662649
if (resultIndex < originalFnTy->getNumResults()) {
663650
auto &origRes = originalResults[resultIndex];
664-
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
665-
origRes.getInterfaceType(), lookupConformance,
666-
substGenericParams, substReplacements, ctx);
667-
auto paramConv = getTangentParameterConventionForOriginalResult(
668-
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
669-
origRes.getInterfaceType()
670-
->getAutoDiffTangentSpace(lookupConformance)
671-
->getCanonicalType(),
672-
origRes.getConvention());
673-
pullbackParams.push_back({resultTanType, paramConv});
651+
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
652+
lookupConformance);
653+
assert(resultTan && "Result type does not have a tangent space?");
654+
auto resultTanType = resultTan->getCanonicalType();
655+
auto paramTanConvention = getTangentParameterConventionForOriginalResult(
656+
resultTanType, origRes.getConvention());
657+
if (!resultTanType->hasArchetype() &&
658+
!resultTanType->hasTypeParameter()) {
659+
auto resultTanType = resultTan->getCanonicalType();
660+
pullbackParams.push_back({resultTanType, paramTanConvention});
661+
} else {
662+
auto gpIndex = substGenericParams.size();
663+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
664+
substGenericParams.push_back(gpType);
665+
substReplacements.push_back(resultTanType);
666+
pullbackParams.push_back({gpType, paramTanConvention});
667+
}
674668
continue;
675669
}
676670
// Handle original `inout` parameter.
@@ -680,18 +674,28 @@ static CanSILFunctionType getAutoDiffPullbackType(
680674
auto paramIndex =
681675
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
682676
auto inoutParam = originalFnTy->getParameters()[paramIndex];
677+
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
678+
lookupConformance);
679+
assert(paramTan && "Parameter type does not have a tangent space?");
683680
// The pullback parameter convention depends on whether the original `inout`
684681
// paramater is a differentiability parameter.
685682
// - If yes, the pullback parameter convention is `@inout`.
686683
// - If no, the pullback parameter convention is `@in_guaranteed`.
687-
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
688-
inoutParam.getInterfaceType(), lookupConformance,
689-
substGenericParams, substReplacements, ctx);
690684
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
691685
auto paramTanConvention = isWrtInoutParameter
692-
? inoutParam.getConvention()
693-
: ParameterConvention::Indirect_In_Guaranteed;
694-
pullbackParams.push_back({inoutParamTanType, paramTanConvention});
686+
? inoutParam.getConvention()
687+
: ParameterConvention::Indirect_In_Guaranteed;
688+
auto paramTanType = paramTan->getCanonicalType();
689+
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
690+
pullbackParams.push_back(
691+
SILParameterInfo(paramTanType, paramTanConvention));
692+
} else {
693+
auto gpIndex = substGenericParams.size();
694+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
695+
substGenericParams.push_back(gpType);
696+
substReplacements.push_back(paramTanType);
697+
pullbackParams.push_back({gpType, paramTanConvention});
698+
}
695699
}
696700

697701
// Collect pullback results.
@@ -703,16 +707,21 @@ static CanSILFunctionType getAutoDiffPullbackType(
703707
// and always appear as pullback parameters.
704708
if (param.isIndirectInOut())
705709
continue;
706-
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
707-
param.getInterfaceType(), lookupConformance,
708-
substGenericParams, substReplacements, ctx);
710+
auto paramTan =
711+
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
712+
assert(paramTan && "Parameter type does not have a tangent space?");
713+
auto paramTanType = paramTan->getCanonicalType();
709714
auto resultTanConvention = getTangentResultConventionForOriginalParameter(
710-
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
711-
param.getInterfaceType()
712-
->getAutoDiffTangentSpace(lookupConformance)
713-
->getCanonicalType(),
714-
param.getConvention());
715-
pullbackResults.push_back({paramTanType, resultTanConvention});
715+
paramTanType, param.getConvention());
716+
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
717+
pullbackResults.push_back({paramTanType, resultTanConvention});
718+
} else {
719+
auto gpIndex = substGenericParams.size();
720+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
721+
substGenericParams.push_back(gpType);
722+
substReplacements.push_back(paramTanType);
723+
pullbackResults.push_back({gpType, resultTanConvention});
724+
}
716725
}
717726
SubstitutionMap substitutions;
718727
if (!substGenericParams.empty()) {

lib/SIL/IR/TypeLowering.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2629,15 +2629,9 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
26292629
if (auto *derivativeId = c.getDerivativeFunctionIdentifier()) {
26302630
auto originalFnTy =
26312631
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
2632-
// Protocol witness derivatives cannot have a derivative generic signature,
2633-
// but class method derivatives can.
2634-
GenericSignature derivativeGenSig = nullptr;
2635-
if (isa<ClassDecl>(c.getAbstractFunctionDecl()->getInnermostTypeContext()))
2636-
derivativeGenSig = derivativeId->getDerivativeGenericSignature();
26372632
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
26382633
derivativeId->getParameterIndices(), derivativeId->getKind(),
2639-
LookUpConformanceInModule(&M),
2640-
derivativeGenSig);
2634+
LookUpConformanceInModule(&M));
26412635
return cast<AnyFunctionType>(derivativeFnTy->getCanonicalType());
26422636
}
26432637

test/AutoDiff/SIL/Parse/sildeclref.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt %s -module-name=sildeclref_parse | %target-sil-opt -module-name=sildeclref_parse | %FileCheck %s
1+
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine=off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine=off | %FileCheck %s
22
// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions.
33

44
import Swift

test/AutoDiff/SILGen/vtable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-silgen %s -requirement-machine=off | %FileCheck %s
22

33
// Test derivative function vtable entries for `@differentiable` class members:
44
// - Methods.

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -requirement-machine=off -verify %s
1+
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
22

33
// Test differentiation transform diagnostics.
44

test/AutoDiff/compiler_crashers_fixed/sr12744-unhandled-pullback-indirect-result.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify %s
1+
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
22

33
// SR-12744: Pullback generation crash for unhandled indirect result.
44
// May be due to inconsistent derivative function type calculation logic in

test/AutoDiff/compiler_crashers_fixed/sr14240-symbol-in-ir-file-not-tbd-file.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift
1+
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
22

33
// REQUIRES: executable_test
44

test/AutoDiff/mangling.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify -requirement-machine=off %s | %FileCheck %s
22

33
import _Differentiation
44

test/AutoDiff/validation-test/forward_mode_simple.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation -Xfrontend -requirement-machine=off)
22
// REQUIRES: executable_test
33

44
import StdlibUnittest

test/AutoDiff/validation-test/inout_parameters.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift
1+
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
22
// REQUIRES: executable_test
33

44
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.

0 commit comments

Comments
 (0)