Skip to content

Revert "[AutoDiff] Fix two derivative type calculation bugs caught by RequirementMachine" #40057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 84 additions & 75 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,32 +408,6 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
return buildGenericSignature(ctx, sig, {}, reqs).getCanonicalSignature();
}

/// Given an original type, computes its tangent type for the purpose of
/// building a linear map using this type. When the original type is an
/// archetype or contains a type parameter, appends a new generic parameter and
/// a corresponding replacement type to the given containers.
static CanType getAutoDiffTangentTypeForLinearMap(
Type originalType,
LookupConformanceFn lookupConformance,
SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
SmallVectorImpl<Type> &substReplacements,
ASTContext &context
) {
auto maybeTanType = originalType->getAutoDiffTangentSpace(lookupConformance);
assert(maybeTanType && "Type does not have a tangent space?");
auto tanType = maybeTanType->getCanonicalType();
// If concrete, the tangent type is concrete.
if (!tanType->hasArchetype() && !tanType->hasTypeParameter())
return tanType;
// Otherwise, the tangent type is a new generic parameter substituted for the
// tangent type.
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, context);
substGenericParams.push_back(gpType);
substReplacements.push_back(tanType);
return gpType;
}

/// Returns the differential type for the given original function type,
/// parameter indices, and result index.
static CanSILFunctionType getAutoDiffDifferentialType(
Expand Down Expand Up @@ -510,32 +484,45 @@ static CanSILFunctionType getAutoDiffDifferentialType(
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
SmallVector<SILParameterInfo, 8> differentialParams;
for (auto &param : diffParams) {
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
param.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto paramConv = getTangentParameterConvention(
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
param.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
param.getConvention());
differentialParams.push_back({paramTanType, paramConv});
auto paramTan =
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
auto paramTanType = paramTan->getCanonicalType();
auto paramConv = getTangentParameterConvention(paramTanType,
param.getConvention());
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
differentialParams.push_back(
{paramTan->getCanonicalType(), paramConv});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
differentialParams.push_back({gpType, paramConv});
}
}
SmallVector<SILResultInfo, 1> differentialResults;
for (auto resultIndex : resultIndices->getIndices()) {
// Handle formal original result.
if (resultIndex < originalFnTy->getNumResults()) {
auto &result = originalResults[resultIndex];
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
result.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto resultConv = getTangentResultConvention(
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
result.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
result.getConvention());
differentialResults.push_back({resultTanType, resultConv});
auto resultTan =
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(resultTan && "Result type does not have a tangent space?");
auto resultTanType = resultTan->getCanonicalType();
auto resultConv =
getTangentResultConvention(resultTanType, result.getConvention());
if (!resultTanType->hasArchetype() &&
!resultTanType->hasTypeParameter()) {
differentialResults.push_back(
{resultTan->getCanonicalType(), resultConv});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(resultTanType);
differentialResults.push_back({gpType, resultConv});
}
continue;
}
// Handle original `inout` parameter.
Expand All @@ -550,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
if (parameterIndices->contains(paramIndex))
continue;
auto inoutParam = originalFnTy->getParameters()[paramIndex];
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
inoutParam.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
differentialResults.push_back(
{inoutParamTanType, ResultConvention::Indirect});
{paramTan->getCanonicalType(), ResultConvention::Indirect});
}

SubstitutionMap substitutions;
Expand Down Expand Up @@ -661,16 +648,23 @@ static CanSILFunctionType getAutoDiffPullbackType(
// Handle formal original result.
if (resultIndex < originalFnTy->getNumResults()) {
auto &origRes = originalResults[resultIndex];
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
origRes.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto paramConv = getTangentParameterConventionForOriginalResult(
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
origRes.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
origRes.getConvention());
pullbackParams.push_back({resultTanType, paramConv});
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(resultTan && "Result type does not have a tangent space?");
auto resultTanType = resultTan->getCanonicalType();
auto paramTanConvention = getTangentParameterConventionForOriginalResult(
resultTanType, origRes.getConvention());
if (!resultTanType->hasArchetype() &&
!resultTanType->hasTypeParameter()) {
auto resultTanType = resultTan->getCanonicalType();
pullbackParams.push_back({resultTanType, paramTanConvention});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(resultTanType);
pullbackParams.push_back({gpType, paramTanConvention});
}
continue;
}
// Handle original `inout` parameter.
Expand All @@ -680,18 +674,28 @@ static CanSILFunctionType getAutoDiffPullbackType(
auto paramIndex =
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
auto inoutParam = originalFnTy->getParameters()[paramIndex];
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
// The pullback parameter convention depends on whether the original `inout`
// paramater is a differentiability parameter.
// - If yes, the pullback parameter convention is `@inout`.
// - If no, the pullback parameter convention is `@in_guaranteed`.
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
inoutParam.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
auto paramTanConvention = isWrtInoutParameter
? inoutParam.getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
pullbackParams.push_back({inoutParamTanType, paramTanConvention});
? inoutParam.getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
auto paramTanType = paramTan->getCanonicalType();
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
pullbackParams.push_back(
SILParameterInfo(paramTanType, paramTanConvention));
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
pullbackParams.push_back({gpType, paramTanConvention});
}
}

// Collect pullback results.
Expand All @@ -703,16 +707,21 @@ static CanSILFunctionType getAutoDiffPullbackType(
// and always appear as pullback parameters.
if (param.isIndirectInOut())
continue;
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
param.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto paramTan =
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
auto paramTanType = paramTan->getCanonicalType();
auto resultTanConvention = getTangentResultConventionForOriginalParameter(
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
param.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
param.getConvention());
pullbackResults.push_back({paramTanType, resultTanConvention});
paramTanType, param.getConvention());
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
pullbackResults.push_back({paramTanType, resultTanConvention});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
pullbackResults.push_back({gpType, resultTanConvention});
}
}
SubstitutionMap substitutions;
if (!substGenericParams.empty()) {
Expand Down
8 changes: 1 addition & 7 deletions lib/SIL/IR/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2629,15 +2629,9 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
if (auto *derivativeId = c.getDerivativeFunctionIdentifier()) {
auto originalFnTy =
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
// Protocol witness derivatives cannot have a derivative generic signature,
// but class method derivatives can.
GenericSignature derivativeGenSig = nullptr;
if (isa<ClassDecl>(c.getAbstractFunctionDecl()->getInnermostTypeContext()))
derivativeGenSig = derivativeId->getDerivativeGenericSignature();
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
derivativeId->getParameterIndices(), derivativeId->getKind(),
LookUpConformanceInModule(&M),
derivativeGenSig);
LookUpConformanceInModule(&M));
return cast<AnyFunctionType>(derivativeFnTy->getCanonicalType());
}

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/SIL/Parse/sildeclref.sil
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-sil-opt %s -module-name=sildeclref_parse | %target-sil-opt -module-name=sildeclref_parse | %FileCheck %s
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine=off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine=off | %FileCheck %s
// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions.

import Swift
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/SILGen/vtable.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
// RUN: %target-swift-frontend -emit-silgen %s -requirement-machine=off | %FileCheck %s

// Test derivative function vtable entries for `@differentiable` class members:
// - Methods.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend -emit-sil -requirement-machine=off -verify %s
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s

// Test differentiation transform diagnostics.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// RUN: %target-swift-frontend -emit-sil -verify -requirement-machine=off %s

// SR-12744: Pullback generation crash for unhandled indirect result.
// May be due to inconsistent derivative function type calculation logic in
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)

// REQUIRES: executable_test

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/mangling.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify %s | %FileCheck %s
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -module-name=mangling -verify -requirement-machine=off %s | %FileCheck %s

import _Differentiation

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/validation-test/forward_mode_simple.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation -Xfrontend -requirement-machine=off)
// REQUIRES: executable_test

import StdlibUnittest
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/validation-test/inout_parameters.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
// REQUIRES: executable_test

// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.
Expand Down