Skip to content

[AutoDiff upstream] Add differentiable function type mangling. #30675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,10 @@ Types
FUNCTION-KIND ::= 'C' // C function pointer type
FUNCTION-KIND ::= 'A' // @auto_closure function type (escaping)
FUNCTION-KIND ::= 'E' // function type (noescape)
FUNCTION-KIND ::= 'F' // @differentiable function type
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
FUNCTION-KIND ::= 'H' // @differentiable(linear) function type
FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping)

function-signature ::= params-type params-type throws? // results and parameters

Expand Down Expand Up @@ -585,14 +589,18 @@ mangled in to disambiguate.
impl-function-type ::= type* 'I' FUNC-ATTRIBUTES '_'
impl-function-type ::= type* generic-signature 'I' FUNC-ATTRIBUTES '_'

FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? PARAM-CONVENTION* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? PARAM-CONVENTION* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?

PATTERN-SUBS ::= 's' // has pattern substitutions
INVOCATION-SUB ::= 'I' // has invocation substitutions
PSEUDO-GENERIC ::= 'P'

CALLEE-ESCAPE ::= 'e' // @escaping (inverse of SIL @noescape)

DIFFERENTIABILITY-KIND ::= DIFFERENTIABLE | LINEAR
DIFFERENTIABLE ::= 'd' // @differentiable
LINEAR ::= 'l' // @differentiable(linear)

CALLEE-CONVENTION ::= 'y' // @callee_unowned
CALLEE-CONVENTION ::= 'g' // @callee_guaranteed
CALLEE-CONVENTION ::= 'x' // @callee_owned
Expand Down
41 changes: 38 additions & 3 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,14 @@ enum class FunctionMetadataConvention: uint8_t {
CFunctionPointer = 3,
};

/// Differentiability kind for function type metadata.
/// Duplicates `DifferentiabilityKind` in AutoDiff.h.
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
NonDifferentiable = 0b00,
Normal = 0b01,
Linear = 0b11
};

/// Flags in a function type metadata record.
template <typename int_type>
class TargetFunctionTypeFlags {
Expand All @@ -777,6 +785,8 @@ class TargetFunctionTypeFlags {
ThrowsMask = 0x01000000U,
ParamFlagsMask = 0x02000000U,
EscapingMask = 0x04000000U,
DifferentiableMask = 0x08000000U,
LinearMask = 0x10000000U
};
int_type Data;

Expand All @@ -801,6 +811,16 @@ class TargetFunctionTypeFlags {
(throws ? ThrowsMask : 0));
}

constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind differentiability) const {
return TargetFunctionTypeFlags<int_type>(
(Data & ~DifferentiableMask & ~LinearMask) |
(differentiability == FunctionMetadataDifferentiabilityKind::Normal
? DifferentiableMask : 0) |
(differentiability == FunctionMetadataDifferentiabilityKind::Linear
? LinearMask : 0));
}

constexpr TargetFunctionTypeFlags<int_type>
withParameterFlags(bool hasFlags) const {
return TargetFunctionTypeFlags<int_type>((Data & ~ParamFlagsMask) |
Expand Down Expand Up @@ -829,6 +849,19 @@ class TargetFunctionTypeFlags {

bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }

bool isDifferentiable() const {
return getDifferentiabilityKind() >=
FunctionMetadataDifferentiabilityKind::Normal;
}

FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
if (bool(Data & DifferentiableMask))
return FunctionMetadataDifferentiabilityKind::Normal;
if (bool(Data & LinearMask))
return FunctionMetadataDifferentiabilityKind::Linear;
return FunctionMetadataDifferentiabilityKind::NonDifferentiable;
}

int_type getIntValue() const {
return Data;
}
Expand All @@ -849,9 +882,10 @@ using FunctionTypeFlags = TargetFunctionTypeFlags<size_t>;
template <typename int_type>
class TargetParameterTypeFlags {
enum : int_type {
ValueOwnershipMask = 0x7F,
VariadicMask = 0x80,
AutoClosureMask = 0x100,
ValueOwnershipMask = 0x7F,
VariadicMask = 0x80,
AutoClosureMask = 0x100,
NoDerivativeMask = 0x200
};
int_type Data;

Expand Down Expand Up @@ -881,6 +915,7 @@ class TargetParameterTypeFlags {
bool isNone() const { return Data == 0; }
bool isVariadic() const { return Data & VariadicMask; }
bool isAutoClosure() const { return Data & AutoClosureMask; }
bool isNoDerivative() const { return Data & NoDerivativeMask; }

ValueOwnership getValueOwnership() const {
return (ValueOwnership)(Data & ValueOwnershipMask);
Expand Down
6 changes: 6 additions & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ NODE(DependentProtocolConformanceInherited)
NODE(DependentProtocolConformanceAssociated)
CONTEXT_NODE(Destructor)
CONTEXT_NODE(DidSet)
NODE(DifferentiableFunctionType)
NODE(EscapingDifferentiableFunctionType)
NODE(LinearFunctionType)
NODE(EscapingLinearFunctionType)
NODE(Directness)
NODE(DynamicAttribute)
NODE(DirectMethodReferenceAttribute)
Expand Down Expand Up @@ -109,6 +113,8 @@ NODE(Identifier)
NODE(Index)
CONTEXT_NODE(IVarInitializer)
CONTEXT_NODE(IVarDestroyer)
NODE(ImplDifferentiable)
NODE(ImplLinear)
NODE(ImplEscaping)
NODE(ImplConvention)
NODE(ImplFunctionAttribute)
Expand Down
19 changes: 18 additions & 1 deletion include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ class TypeDecoder {
case NodeKind::NoEscapeFunctionType:
case NodeKind::AutoClosureType:
case NodeKind::EscapingAutoClosureType:
case NodeKind::DifferentiableFunctionType:
case NodeKind::EscapingDifferentiableFunctionType:
case NodeKind::LinearFunctionType:
case NodeKind::EscapingLinearFunctionType:
case NodeKind::FunctionType: {
if (Node->getNumChildren() < 2)
return BuiltType();
Expand All @@ -507,6 +511,15 @@ class TypeDecoder {
flags.withConvention(FunctionMetadataConvention::CFunctionPointer);
} else if (Node->getKind() == NodeKind::ThinFunctionType) {
flags = flags.withConvention(FunctionMetadataConvention::Thin);
} else if (Node->getKind() == NodeKind::DifferentiableFunctionType ||
Node->getKind() ==
NodeKind::EscapingDifferentiableFunctionType) {
flags = flags.withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind::Normal);
} else if (Node->getKind() == NodeKind::LinearFunctionType ||
Node->getKind() == NodeKind::EscapingLinearFunctionType) {
flags = flags.withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind::Linear);
}

bool isThrow =
Expand All @@ -527,7 +540,11 @@ class TypeDecoder {
.withEscaping(
Node->getKind() == NodeKind::FunctionType ||
Node->getKind() == NodeKind::EscapingAutoClosureType ||
Node->getKind() == NodeKind::EscapingObjCBlock);
Node->getKind() == NodeKind::EscapingObjCBlock ||
Node->getKind() ==
NodeKind::EscapingDifferentiableFunctionType ||
Node->getKind() ==
NodeKind::EscapingLinearFunctionType);

auto result = decodeMangledType(Node->getChild(isThrow ? 2 : 1));
if (!result) return BuiltType();
Expand Down
20 changes: 16 additions & 4 deletions lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ Type ASTBuilder::createFunctionType(
auto parameterFlags = ParameterTypeFlags()
.withValueOwnership(ownership)
.withVariadic(flags.isVariadic())
.withAutoClosure(flags.isAutoClosure());
.withAutoClosure(flags.isAutoClosure())
.withNoDerivative(flags.isNoDerivative());

funcParams.push_back(AnyFunctionType::Param(type, label, parameterFlags));
}
Expand All @@ -386,16 +387,27 @@ Type ASTBuilder::createFunctionType(
break;
}

DifferentiabilityKind diffKind;
switch (flags.getDifferentiabilityKind()) {
case FunctionMetadataDifferentiabilityKind::NonDifferentiable:
diffKind = DifferentiabilityKind::NonDifferentiable;
break;
case FunctionMetadataDifferentiabilityKind::Normal:
diffKind = DifferentiabilityKind::Normal;
break;
case FunctionMetadataDifferentiabilityKind::Linear:
diffKind = DifferentiabilityKind::Linear;
break;
}

auto noescape =
(representation == FunctionTypeRepresentation::Swift
|| representation == FunctionTypeRepresentation::Block)
&& !flags.isEscaping();

FunctionType::ExtInfo incompleteExtInfo(
FunctionTypeRepresentation::Swift,
noescape, flags.throws(),
DifferentiabilityKind::NonDifferentiable,
/*clangFunctionType*/nullptr);
noescape, flags.throws(), diffKind, /*clangFunctionType*/nullptr);

const clang::Type *clangFunctionType = nullptr;
if (representation == FunctionTypeRepresentation::CFunctionPointer)
Expand Down
24 changes: 24 additions & 0 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,18 @@ void ASTMangler::appendImplFunctionType(SILFunctionType *fn) {
if (!fn->isNoEscape())
OpArgs.push_back('e');

// Differentiability kind.
switch (fn->getExtInfo().getDifferentiabilityKind()) {
case DifferentiabilityKind::NonDifferentiable:
break;
case DifferentiabilityKind::Normal:
OpArgs.push_back('d');
break;
case DifferentiabilityKind::Linear:
OpArgs.push_back('l');
break;
}

// <impl-callee-convention>
if (fn->getExtInfo().hasContext()) {
OpArgs.push_back(getParamConvention(fn->getCalleeConvention()));
Expand Down Expand Up @@ -2117,6 +2129,18 @@ void ASTMangler::appendFunctionType(AnyFunctionType *fn, bool isAutoClosure,
case AnyFunctionType::Representation::Thin:
return appendOperator("Xf");
case AnyFunctionType::Representation::Swift:
if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Normal) {
if (fn->isNoEscape())
return appendOperator("XF");
else
return appendOperator("XG");
}
if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
if (fn->isNoEscape())
return appendOperator("XH");
else
return appendOperator("XI");
}
if (isAutoClosure) {
if (fn->isNoEscape())
return appendOperator("XK");
Expand Down
13 changes: 13 additions & 0 deletions lib/Demangling/Demangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,11 @@ NodePointer Demangler::demangleImplFunctionType() {
if (nextIf('e'))
type->addChild(createNode(Node::Kind::ImplEscaping), *this);

if (nextIf('d'))
type->addChild(createNode(Node::Kind::ImplDifferentiable), *this);
if (nextIf('l'))
type->addChild(createNode(Node::Kind::ImplLinear), *this);

const char *CAttr = nullptr;
switch (nextChar()) {
case 'y': CAttr = "@callee_unowned"; break;
Expand Down Expand Up @@ -2791,6 +2796,14 @@ NodePointer Demangler::demangleSpecialType() {
return popFunctionType(Node::Kind::ObjCBlock);
case 'C':
return popFunctionType(Node::Kind::CFunctionPointer);
case 'F':
return popFunctionType(Node::Kind::DifferentiableFunctionType);
case 'G':
return popFunctionType(Node::Kind::EscapingDifferentiableFunctionType);
case 'H':
return popFunctionType(Node::Kind::LinearFunctionType);
case 'I':
return popFunctionType(Node::Kind::EscapingLinearFunctionType);
case 'o':
return createType(createWithChild(Node::Kind::Unowned,
popNode(Node::Kind::Type)));
Expand Down
36 changes: 36 additions & 0 deletions lib/Demangling/NodePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ class NodePrinter {
case Node::Kind::DependentPseudogenericSignature:
case Node::Kind::Destructor:
case Node::Kind::DidSet:
case Node::Kind::DifferentiableFunctionType:
case Node::Kind::EscapingDifferentiableFunctionType:
case Node::Kind::LinearFunctionType:
case Node::Kind::EscapingLinearFunctionType:
case Node::Kind::DirectMethodReferenceAttribute:
case Node::Kind::Directness:
case Node::Kind::DynamicAttribute:
Expand Down Expand Up @@ -386,6 +390,8 @@ class NodePrinter {
case Node::Kind::Index:
case Node::Kind::IVarInitializer:
case Node::Kind::IVarDestroyer:
case Node::Kind::ImplDifferentiable:
case Node::Kind::ImplLinear:
case Node::Kind::ImplEscaping:
case Node::Kind::ImplConvention:
case Node::Kind::ImplFunctionAttribute:
Expand Down Expand Up @@ -1234,6 +1240,22 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
Printer << "@convention(thin) ";
printFunctionType(nullptr, Node);
return nullptr;
case Node::Kind::DifferentiableFunctionType:
Printer << "@differentiable ";
printFunctionType(nullptr, Node);
return nullptr;
case Node::Kind::EscapingDifferentiableFunctionType:
Printer << "@escaping @differentiable ";
printFunctionType(nullptr, Node);
return nullptr;
case Node::Kind::LinearFunctionType:
Printer << "@differentiable(linear) ";
printFunctionType(nullptr, Node);
return nullptr;
case Node::Kind::EscapingLinearFunctionType:
Printer << "@escaping @differentiable(linear) ";
printFunctionType(nullptr, Node);
return nullptr;
case Node::Kind::FunctionType:
case Node::Kind::UncurriedFunctionType:
printFunctionType(nullptr, Node);
Expand Down Expand Up @@ -2026,6 +2048,12 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
return nullptr;
case Node::Kind::LabelList:
return nullptr;
case Node::Kind::ImplDifferentiable:
Printer << "@differentiable";
return nullptr;
case Node::Kind::ImplLinear:
Printer << "@differentiable(linear)";
return nullptr;
case Node::Kind::ImplEscaping:
Printer << "@escaping";
return nullptr;
Expand Down Expand Up @@ -2527,6 +2555,14 @@ void NodePrinter::printEntityType(NodePointer Entity, NodePointer type,
Printer << ' ';
type = dependentType->getFirstChild();
}
if (type->getKind() == Node::Kind::DifferentiableFunctionType)
Printer << "@differentiable ";
else if (type->getKind() == Node::Kind::EscapingDifferentiableFunctionType)
Printer << "@escaping @differentiable ";
else if (type->getKind() == Node::Kind::LinearFunctionType)
Printer << "@differentiable(linear) ";
else if (type->getKind() == Node::Kind::EscapingLinearFunctionType)
Printer << "@escaping @differentiable(linear) ";
printFunctionType(labelList, type);
} else {
print(type);
Expand Down
30 changes: 30 additions & 0 deletions lib/Demangling/OldRemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,26 @@ void Remangler::mangleThinFunctionType(Node *node) {
mangleChildNodes(node); // argument tuple, result type
}

void Remangler::mangleDifferentiableFunctionType(Node *node) {
Buffer << "XF";
mangleChildNodes(node); // argument tuple, result type
}

void Remangler::mangleEscapingDifferentiableFunctionType(Node *node) {
Buffer << "XG";
mangleChildNodes(node); // argument tuple, result type
}

void Remangler::mangleLinearFunctionType(Node *node) {
Buffer << "XH";
mangleChildNodes(node); // argument tuple, result type
}

void Remangler::mangleEscapingLinearFunctionType(Node *node) {
Buffer << "XI";
mangleChildNodes(node); // argument tuple, result type
}

void Remangler::mangleArgumentTuple(Node *node) {
mangleSingleChildNode(node);
}
Expand Down Expand Up @@ -1258,6 +1278,16 @@ void Remangler::mangleImplYield(Node *node) {
mangleChildNodes(node); // impl convention, type
}

void Remangler::mangleImplDifferentiable(Node *node) {
// TODO(TF-750): Check if this code path actually triggers and add a test.
Buffer << 'd';
}

void Remangler::mangleImplLinear(Node *node) {
// TODO(TF-750): Check if this code path actually triggers and add a test.
Buffer << 'l';
}

void Remangler::mangleImplEscaping(Node *node) {
// The old mangler does not encode escaping.
}
Expand Down
Loading