Skip to content

Commit 9bcba98

Browse files
committed
Revert "Revert "[AutoDiff] Fix two derivative type calculation bugs caught by RequirementMachine""
This reverts commit 2629654.
1 parent 1a4ea67 commit 9bcba98

File tree

9 files changed

+83
-92
lines changed

9 files changed

+83
-92
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 75 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
408408
return buildGenericSignature(ctx, sig, {}, reqs).getCanonicalSignature();
409409
}
410410

411+
/// Given an original type, computes its tangent type for the purpose of
412+
/// building a linear map using this type. When the original type is an
413+
/// archetype or contains a type parameter, appends a new generic parameter and
414+
/// a corresponding replacement type to the given containers.
415+
static CanType getAutoDiffTangentTypeForLinearMap(
416+
Type originalType,
417+
LookupConformanceFn lookupConformance,
418+
SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
419+
SmallVectorImpl<Type> &substReplacements,
420+
ASTContext &context
421+
) {
422+
auto maybeTanType = originalType->getAutoDiffTangentSpace(lookupConformance);
423+
assert(maybeTanType && "Type does not have a tangent space?");
424+
auto tanType = maybeTanType->getCanonicalType();
425+
// If concrete, the tangent type is concrete.
426+
if (!tanType->hasArchetype() && !tanType->hasTypeParameter())
427+
return tanType;
428+
// Otherwise, the tangent type is a new generic parameter substituted for the
429+
// tangent type.
430+
auto gpIndex = substGenericParams.size();
431+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, context);
432+
substGenericParams.push_back(gpType);
433+
substReplacements.push_back(tanType);
434+
return gpType;
435+
}
436+
411437
/// Returns the differential type for the given original function type,
412438
/// parameter indices, and result index.
413439
static CanSILFunctionType getAutoDiffDifferentialType(
@@ -484,45 +510,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
484510
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
485511
SmallVector<SILParameterInfo, 8> differentialParams;
486512
for (auto &param : diffParams) {
487-
auto paramTan =
488-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
489-
assert(paramTan && "Parameter type does not have a tangent space?");
490-
auto paramTanType = paramTan->getCanonicalType();
491-
auto paramConv = getTangentParameterConvention(paramTanType,
492-
param.getConvention());
493-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
494-
differentialParams.push_back(
495-
{paramTan->getCanonicalType(), paramConv});
496-
} else {
497-
auto gpIndex = substGenericParams.size();
498-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
499-
substGenericParams.push_back(gpType);
500-
substReplacements.push_back(paramTanType);
501-
differentialParams.push_back({gpType, paramConv});
502-
}
513+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
514+
param.getInterfaceType(), lookupConformance,
515+
substGenericParams, substReplacements, ctx);
516+
auto paramConv = getTangentParameterConvention(
517+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
518+
param.getInterfaceType()
519+
->getAutoDiffTangentSpace(lookupConformance)
520+
->getCanonicalType(),
521+
param.getConvention());
522+
differentialParams.push_back({paramTanType, paramConv});
503523
}
504524
SmallVector<SILResultInfo, 1> differentialResults;
505525
for (auto resultIndex : resultIndices->getIndices()) {
506526
// Handle formal original result.
507527
if (resultIndex < originalFnTy->getNumResults()) {
508528
auto &result = originalResults[resultIndex];
509-
auto resultTan =
510-
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
511-
assert(resultTan && "Result type does not have a tangent space?");
512-
auto resultTanType = resultTan->getCanonicalType();
513-
auto resultConv =
514-
getTangentResultConvention(resultTanType, result.getConvention());
515-
if (!resultTanType->hasArchetype() &&
516-
!resultTanType->hasTypeParameter()) {
517-
differentialResults.push_back(
518-
{resultTan->getCanonicalType(), resultConv});
519-
} else {
520-
auto gpIndex = substGenericParams.size();
521-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
522-
substGenericParams.push_back(gpType);
523-
substReplacements.push_back(resultTanType);
524-
differentialResults.push_back({gpType, resultConv});
525-
}
529+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
530+
result.getInterfaceType(), lookupConformance,
531+
substGenericParams, substReplacements, ctx);
532+
auto resultConv = getTangentResultConvention(
533+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
534+
result.getInterfaceType()
535+
->getAutoDiffTangentSpace(lookupConformance)
536+
->getCanonicalType(),
537+
result.getConvention());
538+
differentialResults.push_back({resultTanType, resultConv});
526539
continue;
527540
}
528541
// Handle original `inout` parameter.
@@ -537,11 +550,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
537550
if (parameterIndices->contains(paramIndex))
538551
continue;
539552
auto inoutParam = originalFnTy->getParameters()[paramIndex];
540-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
541-
lookupConformance);
542-
assert(paramTan && "Parameter type does not have a tangent space?");
553+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
554+
inoutParam.getInterfaceType(), lookupConformance,
555+
substGenericParams, substReplacements, ctx);
543556
differentialResults.push_back(
544-
{paramTan->getCanonicalType(), ResultConvention::Indirect});
557+
{inoutParamTanType, ResultConvention::Indirect});
545558
}
546559

547560
SubstitutionMap substitutions;
@@ -648,23 +661,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
648661
// Handle formal original result.
649662
if (resultIndex < originalFnTy->getNumResults()) {
650663
auto &origRes = originalResults[resultIndex];
651-
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
652-
lookupConformance);
653-
assert(resultTan && "Result type does not have a tangent space?");
654-
auto resultTanType = resultTan->getCanonicalType();
655-
auto paramTanConvention = getTangentParameterConventionForOriginalResult(
656-
resultTanType, origRes.getConvention());
657-
if (!resultTanType->hasArchetype() &&
658-
!resultTanType->hasTypeParameter()) {
659-
auto resultTanType = resultTan->getCanonicalType();
660-
pullbackParams.push_back({resultTanType, paramTanConvention});
661-
} else {
662-
auto gpIndex = substGenericParams.size();
663-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
664-
substGenericParams.push_back(gpType);
665-
substReplacements.push_back(resultTanType);
666-
pullbackParams.push_back({gpType, paramTanConvention});
667-
}
664+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
665+
origRes.getInterfaceType(), lookupConformance,
666+
substGenericParams, substReplacements, ctx);
667+
auto paramConv = getTangentParameterConventionForOriginalResult(
668+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
669+
origRes.getInterfaceType()
670+
->getAutoDiffTangentSpace(lookupConformance)
671+
->getCanonicalType(),
672+
origRes.getConvention());
673+
pullbackParams.push_back({resultTanType, paramConv});
668674
continue;
669675
}
670676
// Handle original `inout` parameter.
@@ -674,28 +680,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
674680
auto paramIndex =
675681
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
676682
auto inoutParam = originalFnTy->getParameters()[paramIndex];
677-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
678-
lookupConformance);
679-
assert(paramTan && "Parameter type does not have a tangent space?");
680683
// The pullback parameter convention depends on whether the original `inout`
681684
// paramater is a differentiability parameter.
682685
// - If yes, the pullback parameter convention is `@inout`.
683686
// - If no, the pullback parameter convention is `@in_guaranteed`.
687+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
688+
inoutParam.getInterfaceType(), lookupConformance,
689+
substGenericParams, substReplacements, ctx);
684690
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
685691
auto paramTanConvention = isWrtInoutParameter
686-
? inoutParam.getConvention()
687-
: ParameterConvention::Indirect_In_Guaranteed;
688-
auto paramTanType = paramTan->getCanonicalType();
689-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
690-
pullbackParams.push_back(
691-
SILParameterInfo(paramTanType, paramTanConvention));
692-
} else {
693-
auto gpIndex = substGenericParams.size();
694-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
695-
substGenericParams.push_back(gpType);
696-
substReplacements.push_back(paramTanType);
697-
pullbackParams.push_back({gpType, paramTanConvention});
698-
}
692+
? inoutParam.getConvention()
693+
: ParameterConvention::Indirect_In_Guaranteed;
694+
pullbackParams.push_back({inoutParamTanType, paramTanConvention});
699695
}
700696

701697
// Collect pullback results.
@@ -707,21 +703,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
707703
// and always appear as pullback parameters.
708704
if (param.isIndirectInOut())
709705
continue;
710-
auto paramTan =
711-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
712-
assert(paramTan && "Parameter type does not have a tangent space?");
713-
auto paramTanType = paramTan->getCanonicalType();
706+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
707+
param.getInterfaceType(), lookupConformance,
708+
substGenericParams, substReplacements, ctx);
714709
auto resultTanConvention = getTangentResultConventionForOriginalParameter(
715-
paramTanType, param.getConvention());
716-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
717-
pullbackResults.push_back({paramTanType, resultTanConvention});
718-
} else {
719-
auto gpIndex = substGenericParams.size();
720-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
721-
substGenericParams.push_back(gpType);
722-
substReplacements.push_back(paramTanType);
723-
pullbackResults.push_back({gpType, resultTanConvention});
724-
}
710+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
711+
param.getInterfaceType()
712+
->getAutoDiffTangentSpace(lookupConformance)
713+
->getCanonicalType(),
714+
param.getConvention());
715+
pullbackResults.push_back({paramTanType, resultTanConvention});
725716
}
726717
SubstitutionMap substitutions;
727718
if (!substGenericParams.empty()) {

test/AutoDiff/SIL/Parse/sildeclref.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine=off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-sil-opt %s -module-name=sildeclref_parse | %target-sil-opt -module-name=sildeclref_parse | %FileCheck %s
22
// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions.
33

44
import Swift

test/AutoDiff/SILGen/vtable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-silgen %s -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
22

33
// Test derivative function vtable entries for `@differentiable` class members:
44
// - Methods.

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
1+
// RUN: %target-swift-frontend -emit-sil -requirement-machine=off -verify %s
22

33
// Test differentiation transform diagnostics.
44

test/AutoDiff/compiler_crashers_fixed/sr12744-unhandled-pullback-indirect-result.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
22

33
// SR-12744: Pullback generation crash for unhandled indirect result.
44
// May be due to inconsistent derivative function type calculation logic in

test/AutoDiff/compiler_crashers_fixed/sr14240-symbol-in-ir-file-not-tbd-file.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift
22

33
// REQUIRES: executable_test
44

test/AutoDiff/mangling.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify -requirement-machine=off %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify %s | %FileCheck %s
22

33
import _Differentiation
44

test/AutoDiff/validation-test/forward_mode_simple.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation -Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
22
// REQUIRES: executable_test
33

44
import StdlibUnittest

test/AutoDiff/validation-test/inout_parameters.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
1+
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

44
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.

0 commit comments

Comments
 (0)