Skip to content

Commit d993112

Browse files
authored
Merge pull request #39505 from rxwei/82549134-autodiff-requirement-machine-fix-1
2 parents 783b619 + 808c7ec commit d993112

File tree

10 files changed

+90
-93
lines changed

10 files changed

+90
-93
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 75 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
396396
GenericSignature()).getCanonicalSignature();
397397
}
398398

399+
/// Given an original type, computes its tangent type for the purpose of
400+
/// building a linear map using this type. When the original type is an
401+
/// archetype or contains a type parameter, appends a new generic parameter and
402+
/// a corresponding replacement type to the given containers.
403+
static CanType getAutoDiffTangentTypeForLinearMap(
404+
Type originalType,
405+
LookupConformanceFn lookupConformance,
406+
SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
407+
SmallVectorImpl<Type> &substReplacements,
408+
ASTContext &context
409+
) {
410+
auto maybeTanType = originalType->getAutoDiffTangentSpace(lookupConformance);
411+
assert(maybeTanType && "Type does not have a tangent space?");
412+
auto tanType = maybeTanType->getCanonicalType();
413+
// If concrete, the tangent type is concrete.
414+
if (!tanType->hasArchetype() && !tanType->hasTypeParameter())
415+
return tanType;
416+
// Otherwise, the tangent type is a new generic parameter substituted for the
417+
// tangent type.
418+
auto gpIndex = substGenericParams.size();
419+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, context);
420+
substGenericParams.push_back(gpType);
421+
substReplacements.push_back(tanType);
422+
return gpType;
423+
}
424+
399425
/// Returns the differential type for the given original function type,
400426
/// parameter indices, and result index.
401427
static CanSILFunctionType getAutoDiffDifferentialType(
@@ -471,45 +497,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
471497
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
472498
SmallVector<SILParameterInfo, 8> differentialParams;
473499
for (auto &param : diffParams) {
474-
auto paramTan =
475-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
476-
assert(paramTan && "Parameter type does not have a tangent space?");
477-
auto paramTanType = paramTan->getCanonicalType();
478-
auto paramConv = getTangentParameterConvention(paramTanType,
479-
param.getConvention());
480-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
481-
differentialParams.push_back(
482-
{paramTan->getCanonicalType(), paramConv});
483-
} else {
484-
auto gpIndex = substGenericParams.size();
485-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
486-
substGenericParams.push_back(gpType);
487-
substReplacements.push_back(paramTanType);
488-
differentialParams.push_back({gpType, paramConv});
489-
}
500+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
501+
param.getInterfaceType(), lookupConformance,
502+
substGenericParams, substReplacements, ctx);
503+
auto paramConv = getTangentParameterConvention(
504+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
505+
param.getInterfaceType()
506+
->getAutoDiffTangentSpace(lookupConformance)
507+
->getCanonicalType(),
508+
param.getConvention());
509+
differentialParams.push_back({paramTanType, paramConv});
490510
}
491511
SmallVector<SILResultInfo, 1> differentialResults;
492512
for (auto resultIndex : resultIndices->getIndices()) {
493513
// Handle formal original result.
494514
if (resultIndex < originalFnTy->getNumResults()) {
495515
auto &result = originalResults[resultIndex];
496-
auto resultTan =
497-
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
498-
assert(resultTan && "Result type does not have a tangent space?");
499-
auto resultTanType = resultTan->getCanonicalType();
500-
auto resultConv =
501-
getTangentResultConvention(resultTanType, result.getConvention());
502-
if (!resultTanType->hasArchetype() &&
503-
!resultTanType->hasTypeParameter()) {
504-
differentialResults.push_back(
505-
{resultTan->getCanonicalType(), resultConv});
506-
} else {
507-
auto gpIndex = substGenericParams.size();
508-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
509-
substGenericParams.push_back(gpType);
510-
substReplacements.push_back(resultTanType);
511-
differentialResults.push_back({gpType, resultConv});
512-
}
516+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
517+
result.getInterfaceType(), lookupConformance,
518+
substGenericParams, substReplacements, ctx);
519+
auto resultConv = getTangentResultConvention(
520+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
521+
result.getInterfaceType()
522+
->getAutoDiffTangentSpace(lookupConformance)
523+
->getCanonicalType(),
524+
result.getConvention());
525+
differentialResults.push_back({resultTanType, resultConv});
513526
continue;
514527
}
515528
// Handle original `inout` parameter.
@@ -524,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
524537
if (parameterIndices->contains(paramIndex))
525538
continue;
526539
auto inoutParam = originalFnTy->getParameters()[paramIndex];
527-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
528-
lookupConformance);
529-
assert(paramTan && "Parameter type does not have a tangent space?");
540+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
541+
inoutParam.getInterfaceType(), lookupConformance,
542+
substGenericParams, substReplacements, ctx);
530543
differentialResults.push_back(
531-
{paramTan->getCanonicalType(), ResultConvention::Indirect});
544+
{inoutParamTanType, ResultConvention::Indirect});
532545
}
533546

534547
SubstitutionMap substitutions;
@@ -635,23 +648,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
635648
// Handle formal original result.
636649
if (resultIndex < originalFnTy->getNumResults()) {
637650
auto &origRes = originalResults[resultIndex];
638-
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
639-
lookupConformance);
640-
assert(resultTan && "Result type does not have a tangent space?");
641-
auto resultTanType = resultTan->getCanonicalType();
642-
auto paramTanConvention = getTangentParameterConventionForOriginalResult(
643-
resultTanType, origRes.getConvention());
644-
if (!resultTanType->hasArchetype() &&
645-
!resultTanType->hasTypeParameter()) {
646-
auto resultTanType = resultTan->getCanonicalType();
647-
pullbackParams.push_back({resultTanType, paramTanConvention});
648-
} else {
649-
auto gpIndex = substGenericParams.size();
650-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
651-
substGenericParams.push_back(gpType);
652-
substReplacements.push_back(resultTanType);
653-
pullbackParams.push_back({gpType, paramTanConvention});
654-
}
651+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
652+
origRes.getInterfaceType(), lookupConformance,
653+
substGenericParams, substReplacements, ctx);
654+
auto paramConv = getTangentParameterConventionForOriginalResult(
655+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
656+
origRes.getInterfaceType()
657+
->getAutoDiffTangentSpace(lookupConformance)
658+
->getCanonicalType(),
659+
origRes.getConvention());
660+
pullbackParams.push_back({resultTanType, paramConv});
655661
continue;
656662
}
657663
// Handle original `inout` parameter.
@@ -661,28 +667,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
661667
auto paramIndex =
662668
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
663669
auto inoutParam = originalFnTy->getParameters()[paramIndex];
664-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
665-
lookupConformance);
666-
assert(paramTan && "Parameter type does not have a tangent space?");
667670
// The pullback parameter convention depends on whether the original `inout`
668671
// paramater is a differentiability parameter.
669672
// - If yes, the pullback parameter convention is `@inout`.
670673
// - If no, the pullback parameter convention is `@in_guaranteed`.
674+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
675+
inoutParam.getInterfaceType(), lookupConformance,
676+
substGenericParams, substReplacements, ctx);
671677
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
672678
auto paramTanConvention = isWrtInoutParameter
673-
? inoutParam.getConvention()
674-
: ParameterConvention::Indirect_In_Guaranteed;
675-
auto paramTanType = paramTan->getCanonicalType();
676-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
677-
pullbackParams.push_back(
678-
SILParameterInfo(paramTanType, paramTanConvention));
679-
} else {
680-
auto gpIndex = substGenericParams.size();
681-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
682-
substGenericParams.push_back(gpType);
683-
substReplacements.push_back(paramTanType);
684-
pullbackParams.push_back({gpType, paramTanConvention});
685-
}
679+
? inoutParam.getConvention()
680+
: ParameterConvention::Indirect_In_Guaranteed;
681+
pullbackParams.push_back({inoutParamTanType, paramTanConvention});
686682
}
687683

688684
// Collect pullback results.
@@ -694,21 +690,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
694690
// and always appear as pullback parameters.
695691
if (param.isIndirectInOut())
696692
continue;
697-
auto paramTan =
698-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
699-
assert(paramTan && "Parameter type does not have a tangent space?");
700-
auto paramTanType = paramTan->getCanonicalType();
693+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
694+
param.getInterfaceType(), lookupConformance,
695+
substGenericParams, substReplacements, ctx);
701696
auto resultTanConvention = getTangentResultConventionForOriginalParameter(
702-
paramTanType, param.getConvention());
703-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
704-
pullbackResults.push_back({paramTanType, resultTanConvention});
705-
} else {
706-
auto gpIndex = substGenericParams.size();
707-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
708-
substGenericParams.push_back(gpType);
709-
substReplacements.push_back(paramTanType);
710-
pullbackResults.push_back({gpType, resultTanConvention});
711-
}
697+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
698+
param.getInterfaceType()
699+
->getAutoDiffTangentSpace(lookupConformance)
700+
->getCanonicalType(),
701+
param.getConvention());
702+
pullbackResults.push_back({paramTanType, resultTanConvention});
712703
}
713704
SubstitutionMap substitutions;
714705
if (!substGenericParams.empty()) {

lib/SIL/IR/TypeLowering.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2599,9 +2599,15 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
25992599
if (auto *derivativeId = c.getDerivativeFunctionIdentifier()) {
26002600
auto originalFnTy =
26012601
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
2602+
// Protocol witness derivatives cannot have a derivative generic signature,
2603+
// but class method derivatives can.
2604+
GenericSignature derivativeGenSig = nullptr;
2605+
if (isa<ClassDecl>(c.getAbstractFunctionDecl()->getInnermostTypeContext()))
2606+
derivativeGenSig = derivativeId->getDerivativeGenericSignature();
26022607
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
26032608
derivativeId->getParameterIndices(), derivativeId->getKind(),
2604-
LookUpConformanceInModule(&M));
2609+
LookUpConformanceInModule(&M),
2610+
derivativeGenSig);
26052611
return cast<AnyFunctionType>(derivativeFnTy->getCanonicalType());
26062612
}
26072613

test/AutoDiff/SIL/Parse/sildeclref.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine=off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-sil-opt %s -module-name=sildeclref_parse | %target-sil-opt -module-name=sildeclref_parse | %FileCheck %s
22
// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions.
33

44
import Swift

test/AutoDiff/SILGen/vtable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-silgen %s -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
22

33
// Test derivative function vtable entries for `@differentiable` class members:
44
// - Methods.

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
1+
// RUN: %target-swift-frontend -emit-sil -requirement-machine=off -verify %s
22

33
// Test differentiation transform diagnostics.
44

test/AutoDiff/compiler_crashers_fixed/sr12744-unhandled-pullback-indirect-result.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
22

33
// SR-12744: Pullback generation crash for unhandled indirect result.
44
// May be due to inconsistent derivative function type calculation logic in

test/AutoDiff/compiler_crashers_fixed/sr14240-symbol-in-ir-file-not-tbd-file.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift
22

33
// REQUIRES: executable_test
44

test/AutoDiff/mangling.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify -requirement-machine=off %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify %s | %FileCheck %s
22

33
import _Differentiation
44

test/AutoDiff/validation-test/forward_mode_simple.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation -Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
22
// REQUIRES: executable_test
33

44
import StdlibUnittest

test/AutoDiff/validation-test/inout_parameters.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

44
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.

0 commit comments

Comments
 (0)