Skip to content

[AutoDiff] WIP: Use owned callee convention for linear maps. #34935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions include/swift/SIL/AbstractionPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ class AbstractionPattern {
/// member function. OrigType is valid and is a function type. CXXMethod is
/// valid.
PartialCurriedCXXOperatorMethodType,
/// A derivative function type.
DerivativeFunctionType,
/// A Swift function whose parameters and results are opaque. This is
/// like `AP::Type<T>((T) -> T)`, except that the number of parameters is
/// unspecified.
Expand All @@ -225,31 +227,31 @@ class AbstractionPattern {
///
/// differentiable_function
/// [parameters 0]
/// %0 : $@callee_guaranteed (Float) -> Float
/// %0 : $@callee_owned (Float) -> Float
/// with_derivative {
/// %1 : $@callee_guaranteed (Float) -> (
/// %1 : $@callee_owned (Float) -> (
/// Float,
/// @owned @callee_guaranteed (Float) -> Float
/// @owned @callee_owned (Float) -> Float
/// ),
/// %2 : $@callee_guaranteed (Float) -> (
/// %2 : $@callee_owned (Float) -> (
/// Float,
/// @owned @callee_guaranteed (Float) -> Float
/// @owned @callee_owned (Float) -> Float
/// )
/// }
///
/// The invariant-respecting abstraction of this value to `AP::Opaque` is:
///
/// differentiable_function
/// [parameters 0]
/// %3 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float
/// %3 : $@callee_owned (@in_guaranteed Float) -> @out Float
/// with_derivative {
/// %4 : $@callee_guaranteed (@in_guaranteed Float) -> (
/// %4 : $@callee_owned (@in_guaranteed Float) -> (
/// @out Float,
/// @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
/// @owned @callee_owned (@in_guaranteed Float) -> @out Float
/// ),
/// %5 : $@callee_guaranteed (@in_guaranteed Float) -> (
/// %5 : $@callee_owned (@in_guaranteed Float) -> (
/// @out Float,
/// @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
/// @owned @callee_owned (@in_guaranteed Float) -> @out Float
/// )
/// }
///
Expand Down Expand Up @@ -618,6 +620,7 @@ class AbstractionPattern {
case Kind::CurriedCXXOperatorMethodType:
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::ObjCCompletionHandlerArgumentsType:
case Kind::DerivativeFunctionType:
return true;
case Kind::Invalid:
case Kind::Opaque:
Expand Down Expand Up @@ -1029,6 +1032,7 @@ class AbstractionPattern {
case Kind::CXXOperatorMethodType:
case Kind::CurriedCXXOperatorMethodType:
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::DerivativeFunctionType:
case Kind::Type:
case Kind::Discard:
return OrigType;
Expand Down Expand Up @@ -1068,6 +1072,7 @@ class AbstractionPattern {
case Kind::CXXOperatorMethodType:
case Kind::CurriedCXXOperatorMethodType:
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::DerivativeFunctionType:
case Kind::Type:
case Kind::Discard:
case Kind::ObjCCompletionHandlerArgumentsType:
Expand All @@ -1093,6 +1098,7 @@ class AbstractionPattern {
case Kind::Tuple:
case Kind::Type:
case Kind::Discard:
case Kind::DerivativeFunctionType:
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
return false;
Expand Down Expand Up @@ -1120,6 +1126,11 @@ class AbstractionPattern {
return getKind() == Kind::Discard;
}

/// True if the value is a derivative function type.
bool isDerivativeFunctionType() const {
return getKind() == Kind::DerivativeFunctionType;
}

/// Return whether this abstraction pattern represents a Clang type.
/// If so, it is legal to return getClangType().
bool isClangType() const {
Expand Down Expand Up @@ -1190,6 +1201,7 @@ class AbstractionPattern {
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::ObjCCompletionHandlerArgumentsType:
case Kind::DerivativeFunctionType:
return false;
case Kind::PartialCurriedObjCMethodType:
case Kind::CurriedObjCMethodType:
Expand All @@ -1211,6 +1223,7 @@ class AbstractionPattern {
return typename CanTypeWrapperTraits<TYPE>::type();
case Kind::Tuple:
return typename CanTypeWrapperTraits<TYPE>::type();
case Kind::DerivativeFunctionType:
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
return typename CanTypeWrapperTraits<TYPE>::type();
Expand Down Expand Up @@ -1266,6 +1279,7 @@ class AbstractionPattern {
return false;
case Kind::Type:
case Kind::Discard:
case Kind::DerivativeFunctionType:
return getType() == type;
}
llvm_unreachable("bad kind");
Expand Down Expand Up @@ -1294,6 +1308,7 @@ class AbstractionPattern {
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::DerivativeFunctionType:
return false;
case Kind::ObjCCompletionHandlerArgumentsType:
case Kind::Tuple:
Expand Down Expand Up @@ -1325,6 +1340,7 @@ class AbstractionPattern {
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::DerivativeFunctionType:
llvm_unreachable("pattern is not a tuple");
case Kind::Tuple:
return getNumTupleElements_Stored();
Expand Down
2 changes: 1 addition & 1 deletion include/swift/SILOptimizer/Differentiation/ADContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct NestedApplyInfo {
AutoDiffConfig config;
/// The original pullback type before reabstraction. `None` if the pullback
/// type is not reabstracted.
Optional<CanSILFunctionType> originalPullbackType;
CanSILFunctionType originalPullbackType;
};

/// Per-module contextual information for the Differentiation pass.
Expand Down
26 changes: 22 additions & 4 deletions lib/SIL/IR/AbstractionPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ AbstractionPattern::getOptional(AbstractionPattern object) {
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::ObjCCompletionHandlerArgumentsType:
case Kind::DerivativeFunctionType:
llvm_unreachable("cannot add optionality to non-type abstraction");
case Kind::Opaque:
return AbstractionPattern::getOpaque();
Expand Down Expand Up @@ -300,6 +301,7 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) {
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::DerivativeFunctionType:
return false;
case Kind::Opaque:
return true;
Expand Down Expand Up @@ -376,6 +378,7 @@ AbstractionPattern::getTupleElementType(unsigned index) const {
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::DerivativeFunctionType:
llvm_unreachable("function types are not tuples");
case Kind::Opaque:
return *this;
Expand Down Expand Up @@ -484,6 +487,7 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
case Kind::Opaque:
return *this;
case Kind::Type:
case Kind::DerivativeFunctionType:
if (isTypeParameterOrOpaqueArchetype())
return AbstractionPattern::getOpaque();
return AbstractionPattern(getGenericSignatureForFunctionComponent(),
Expand Down Expand Up @@ -633,6 +637,7 @@ AbstractionPattern::getObjCMethodAsyncCompletionHandlerType(
case Kind::Opaque:
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::DerivativeFunctionType:
case Kind::Type:
return AbstractionPattern(getGenericSignature(),
swiftCompletionHandlerType);
Expand Down Expand Up @@ -685,7 +690,8 @@ AbstractionPattern::getFunctionParamType(unsigned index) const {
switch (getKind()) {
case Kind::Opaque:
return *this;
case Kind::Type: {
case Kind::Type:
case Kind::DerivativeFunctionType: {
if (isTypeParameterOrOpaqueArchetype())
return AbstractionPattern::getOpaque();
auto params = cast<AnyFunctionType>(getType()).getParams();
Expand Down Expand Up @@ -883,6 +889,7 @@ AbstractionPattern AbstractionPattern::getOptionalObjectType() const {
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::ObjCCompletionHandlerArgumentsType:
case Kind::DerivativeFunctionType:
llvm_unreachable("pattern for function or tuple cannot be for optional");

case Kind::Opaque:
Expand Down Expand Up @@ -928,6 +935,7 @@ AbstractionPattern AbstractionPattern::getReferenceStorageReferentType() const {
case Kind::OpaqueFunction:
case Kind::OpaqueDerivativeFunction:
case Kind::ObjCCompletionHandlerArgumentsType:
case Kind::DerivativeFunctionType:
return *this;
case Kind::Type:
return AbstractionPattern(getGenericSignature(),
Expand Down Expand Up @@ -965,10 +973,13 @@ void AbstractionPattern::print(raw_ostream &out) const {
return;
case Kind::Type:
case Kind::Discard:
case Kind::DerivativeFunctionType:
out << (getKind() == Kind::Type
? "AP::Type" :
getKind() == Kind::Discard
? "AP::Discard" : "<<UNHANDLED CASE>>");
? "AP::Discard" :
getKind() == Kind::DerivativeFunctionType
? "AP:DerivativeFunctionType" : "<<UNHANDLED CASE>>");
if (auto sig = getGenericSignature()) {
sig->print(out);
}
Expand Down Expand Up @@ -1193,6 +1204,7 @@ const {
case Kind::ClangType:
case Kind::Type:
case Kind::Discard:
case Kind::DerivativeFunctionType:
auto memberTy = getType()->getTypeOfMember(member->getModuleContext(),
member, origMemberInterfaceType)
->getCanonicalType(getGenericSignature());
Expand All @@ -1215,9 +1227,12 @@ AbstractionPattern AbstractionPattern::getAutoDiffDerivativeFunctionType(
parameterIndices, kind, lookupConformance, derivativeGenericSignature,
makeSelfParamFirst);
assert(derivativeFnTy);
return AbstractionPattern(
AbstractionPattern pattern;
pattern.initSwiftType(
getGenericSignature(),
derivativeFnTy->getCanonicalType(getGenericSignature()));
derivativeFnTy->getCanonicalType(getGenericSignature()),
Kind::DerivativeFunctionType);
return pattern;
}
case Kind::Opaque:
return getOpaqueDerivativeFunction();
Expand Down Expand Up @@ -1251,6 +1266,7 @@ AbstractionPattern::getResultConvention(TypeConverter &TC) const {
case Kind::CXXOperatorMethodType:
case Kind::CurriedCXXOperatorMethodType:
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::DerivativeFunctionType:
// Function types are always passed directly
return Direct;

Expand Down Expand Up @@ -1295,6 +1311,7 @@ AbstractionPattern::getParameterConvention(TypeConverter &TC) const {
case Kind::CXXOperatorMethodType:
case Kind::CurriedCXXOperatorMethodType:
case Kind::PartialCurriedCXXOperatorMethodType:
case Kind::DerivativeFunctionType:
// Function types are always passed directly
return Direct;

Expand Down Expand Up @@ -1340,6 +1357,7 @@ AbstractionPattern::operator==(const AbstractionPattern &other) const {

case Kind::Type:
case Kind::Discard:
case Kind::DerivativeFunctionType:
return OrigType == other.OrigType
&& GenericSig == other.GenericSig;

Expand Down
50 changes: 37 additions & 13 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,10 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
/// Collects the semantic results of the given function type in
/// `originalResults`. The semantic results are formal results followed by
/// `inout` parameters, in type order.
static void
getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
IndexSubset *&inoutParameterIndices,
SmallVectorImpl<SILResultInfo> &originalResults) {
static void getAutoDiffSemanticResults(
SILFunctionType *functionType, IndexSubset *parameterIndices,
IndexSubset *&inoutParameterIndices,
SmallVectorImpl<SILResultInfo> &originalResults) {
auto &C = functionType->getASTContext();
SmallVector<unsigned, 4> inoutParamIndices;
// Collect original formal results.
Expand All @@ -361,9 +361,10 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
IndexSubset::get(C, parameterIndices->getCapacity(), inoutParamIndices);
}

static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignature sig,
CanType tanType,
CanType origTypeOfAbstraction) {
static CanGenericSignature
buildDifferentiableGenericSignature(CanGenericSignature sig,
CanType tanType,
CanType origTypeOfAbstraction) {
if (!sig)
return sig;

Expand Down Expand Up @@ -504,8 +505,8 @@ static CanSILFunctionType getAutoDiffDifferentialType(

IndexSubset *inoutParamIndices;
SmallVector<SILResultInfo, 2> originalResults;
getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
originalResults);
getAutoDiffSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
originalResults);

SmallVector<SILParameterInfo, 4> diffParams;
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
Expand Down Expand Up @@ -569,7 +570,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
}
return SILFunctionType::get(
GenericSignature(), SILFunctionType::ExtInfo(), SILCoroutineKind::None,
ParameterConvention::Direct_Guaranteed, differentialParams, {},
ParameterConvention::Direct_Owned, differentialParams, {},
differentialResults, None, substitutions,
/*invocationSubstitutions*/ SubstitutionMap(), ctx);
}
Expand All @@ -588,8 +589,8 @@ static CanSILFunctionType getAutoDiffPullbackType(

IndexSubset *inoutParamIndices;
SmallVector<SILResultInfo, 2> originalResults;
getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
originalResults);
getAutoDiffSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
originalResults);

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterConventionForOriginalResult =
Expand Down Expand Up @@ -726,7 +727,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
}
return SILFunctionType::get(
GenericSignature(), SILFunctionType::ExtInfo(), SILCoroutineKind::None,
ParameterConvention::Direct_Guaranteed, pullbackParams, {},
ParameterConvention::Direct_Owned, pullbackParams, {},
pullbackResults, None, substitutions,
/*invocationSubstitutions*/ SubstitutionMap(), ctx);
}
Expand Down Expand Up @@ -2022,6 +2023,29 @@ static CanSILFunctionType getSILFunctionType(
destructurer.destructure(origResultType, substFormalResultType);
}

// If it's a derivative function, its linear map result has `@callee_owned`
// convention.
if (origType.isDerivativeFunctionType()) {
assert(results.size() == 2);
auto &linearMapResult = results[1];
auto linearMapType = linearMapResult.getInterfaceType()
->getAs<SILFunctionType>();
auto newLinearMapType = SILFunctionType::get(
linearMapType->getInvocationGenericSignature(),
linearMapType->getExtInfo(),
linearMapType->getCoroutineKind(),
ParameterConvention::Direct_Owned,
linearMapType->getParameters(),
linearMapType->getYields(),
linearMapType->getResults(),
linearMapType->getOptionalErrorResult(),
linearMapType->getPatternSubstitutions(),
linearMapType->getInvocationSubstitutions(),
linearMapType->getASTContext(),
linearMapType->getWitnessMethodConformanceOrInvalid());
linearMapResult = linearMapResult.getWithInterfaceType(newLinearMapType);
}

// Lower the capture context parameters, if any.
if (constant && constant->getAnyFunctionRef()) {
// Lower in the context of the closure. Since the set of captures is a
Expand Down
Loading