Skip to content

[ABI] [AutoDiff] Fix function differentiability kind metadata and mangling. #36601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -559,24 +559,24 @@ Types
FUNCTION-KIND ::= 'zC' C-TYPE // C function pointer type with with non-canonical C type
FUNCTION-KIND ::= 'A' // @auto_closure function type (escaping)
FUNCTION-KIND ::= 'E' // function type (noescape)
FUNCTION-KIND ::= 'F' // @differentiable function type
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
FUNCTION-KIND ::= 'H' // @differentiable(_linear) function type
FUNCTION-KIND ::= 'I' // @differentiable(_linear) function type (escaping)

C-TYPE is mangled according to the Itanium ABI, and prefixed with the length.
Non-ASCII identifiers are preserved as-is; we do not use Punycode.

function-signature ::= params-type params-type async? sendable? throws? // results and parameters
function-signature ::= params-type params-type async? sendable? throws? differentiable? // results and parameters

params-type ::= type 'z'? 'h'? // tuple in case of multiple parameters or a single parameter with a single tuple type
params-type ::= type 'z'? 'h'? // tuple in case of multiple parameters or a single parameter with a single tuple type
// with optional inout convention, shared convention. parameters don't have labels,
// they are mangled separately as part of the entity.
params-type ::= empty-list // shortcut for no parameters
params-type ::= empty-list // shortcut for no parameters

sendable ::= 'J' // @Sendable on function types
sendable ::= 'J' // @Sendable on function types
async ::= 'Y' // 'async' annotation on function types
throws ::= 'K' // 'throws' annotation on function types
differentiable ::= 'jf' // @differentiable(_forward) on function type
differentiable ::= 'jr' // @differentiable(reverse) on function type
differentiable ::= 'jd' // @differentiable on function type
differentiable ::= 'jl' // @differentiable(_linear) on function type

type-list ::= list-type '_' list-type* // list of types
type-list ::= empty-list
Expand Down
23 changes: 23 additions & 0 deletions include/swift/ABI/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,7 @@ struct TargetFunctionTypeMetadata : public TargetMetadata<Runtime> {
bool isAsync() const { return Flags.isAsync(); }
bool isThrowing() const { return Flags.isThrowing(); }
bool isSendable() const { return Flags.isSendable(); }
bool isDifferentiable() const { return Flags.isDifferentiable(); }
bool hasParameterFlags() const { return Flags.hasParameterFlags(); }
bool isEscaping() const { return Flags.isEscaping(); }

Expand All @@ -1675,6 +1676,28 @@ struct TargetFunctionTypeMetadata : public TargetMetadata<Runtime> {
return reinterpret_cast<const uint32_t *>(getParameters() +
getNumParameters());
}

TargetFunctionMetadataDifferentiabilityKind<StoredSize> *
getDifferentiabilityKindAddress() {
assert(isDifferentiable());
void *previousEndAddr = hasParameterFlags()
? reinterpret_cast<void *>(getParameterFlags() + getNumParameters())
: reinterpret_cast<void *>(getParameters() + getNumParameters());
return reinterpret_cast<
TargetFunctionMetadataDifferentiabilityKind<StoredSize> *>(
llvm::alignAddr(previousEndAddr,
llvm::Align(alignof(typename Runtime::StoredPointer))));
}

TargetFunctionMetadataDifferentiabilityKind<StoredSize>
getDifferentiabilityKind() const {
if (isDifferentiable()) {
return *const_cast<TargetFunctionTypeMetadata<Runtime> *>(this)
->getDifferentiabilityKindAddress();
}
return TargetFunctionMetadataDifferentiabilityKind<StoredSize>
::NonDifferentiable;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use swift/ABI/TrailingObjects.h to help with the trailing allocation management here. Where do you initialize the tail storage?

Copy link
Contributor Author

@rxwei rxwei Mar 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, initially I wanted to use that but noticed that parameters and parameter flags are tail-allocated manually, so I followed the same approach. Would it be odd to mix both, no? The allocation size calculation is in FunctionCacheEntry::getExtraAllocationSize(); I aligned the previously calculated size (because there are a few 32-bit ints for parameter flags) and appended a differentiability word after it.

};
using FunctionTypeMetadata = TargetFunctionTypeMetadata<InProcess>;

Expand Down
53 changes: 34 additions & 19 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -787,13 +787,29 @@ enum class FunctionMetadataConvention: uint8_t {

/// Differentiability kind for function type metadata.
/// Duplicates `DifferentiabilityKind` in AST/AutoDiff.h.
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
NonDifferentiable = 0b00000,
Forward = 0b00001,
Reverse = 0b00010,
Normal = 0b00011,
Linear = 0b10000,
template <typename int_type>
struct TargetFunctionMetadataDifferentiabilityKind {
enum Value : int_type {
NonDifferentiable = 0,
Forward = 1,
Reverse = 2,
Normal = 3,
Linear = 4,
} Value;

constexpr TargetFunctionMetadataDifferentiabilityKind(
enum Value value = NonDifferentiable) : Value(value) {}

int_type getIntValue() const {
return (int_type)Value;
}

bool isDifferentiable() const {
return Value != NonDifferentiable;
}
};
using FunctionMetadataDifferentiabilityKind =
TargetFunctionMetadataDifferentiabilityKind<size_t>;

/// Flags in a function type metadata record.
template <typename int_type>
Expand All @@ -808,8 +824,7 @@ class TargetFunctionTypeFlags {
ThrowsMask = 0x01000000U,
ParamFlagsMask = 0x02000000U,
EscapingMask = 0x04000000U,
DifferentiabilityMask = 0x98000000U,
DifferentiabilityShift = 27U,
DifferentiableMask = 0x08000000U,
AsyncMask = 0x20000000U,
SendableMask = 0x40000000U,
};
Expand Down Expand Up @@ -842,10 +857,10 @@ class TargetFunctionTypeFlags {
(throws ? ThrowsMask : 0));
}

constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind differentiabilityKind) const {
return TargetFunctionTypeFlags((Data & ~DifferentiabilityMask)
| (int_type(differentiabilityKind) << DifferentiabilityShift));
constexpr TargetFunctionTypeFlags<int_type>
withDifferentiable(bool differentiable) const {
return TargetFunctionTypeFlags<int_type>((Data & ~DifferentiableMask) |
(differentiable ? DifferentiableMask : 0));
}

constexpr TargetFunctionTypeFlags<int_type>
Expand Down Expand Up @@ -888,13 +903,7 @@ class TargetFunctionTypeFlags {
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }

bool isDifferentiable() const {
return getDifferentiabilityKind() !=
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
}

FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
return FunctionMetadataDifferentiabilityKind(
(Data & DifferentiabilityMask) >> DifferentiabilityShift);
return bool (Data & DifferentiableMask);
}

int_type getIntValue() const {
Expand Down Expand Up @@ -947,6 +956,12 @@ class TargetParameterTypeFlags {
(Data & ~AutoClosureMask) | (isAutoClosure ? AutoClosureMask : 0));
}

constexpr TargetParameterTypeFlags<int_type>
withNoDerivative(bool isNoDerivative) const {
return TargetParameterTypeFlags<int_type>(
(Data & ~NoDerivativeMask) | (isNoDerivative ? NoDerivativeMask : 0));
}

bool isNone() const { return Data == 0; }
bool isVariadic() const { return Data & VariadicMask; }
bool isAutoClosure() const { return Data & AutoClosureMask; }
Expand Down
6 changes: 4 additions & 2 deletions include/swift/AST/ASTDemangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ class ASTBuilder {

Type createTupleType(ArrayRef<Type> eltTypes, StringRef labels);

Type createFunctionType(ArrayRef<Demangle::FunctionParam<Type>> params,
Type output, FunctionTypeFlags flags);
Type createFunctionType(
ArrayRef<Demangle::FunctionParam<Type>> params,
Type output, FunctionTypeFlags flags,
FunctionMetadataDifferentiabilityKind diffKind);

Type createImplFunctionType(
Demangle::ImplParameterConvention calleeConvention,
Expand Down
5 changes: 1 addition & 4 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ NODE(DependentProtocolConformanceInherited)
NODE(DependentProtocolConformanceAssociated)
CONTEXT_NODE(Destructor)
CONTEXT_NODE(DidSet)
NODE(DifferentiableFunctionType)
NODE(EscapingDifferentiableFunctionType)
NODE(LinearFunctionType)
NODE(EscapingLinearFunctionType)
NODE(Directness)
NODE(DynamicAttribute)
NODE(DirectMethodReferenceAttribute)
Expand All @@ -86,6 +82,7 @@ NODE(ErrorType)
NODE(EscapingAutoClosureType)
NODE(NoEscapeFunctionType)
NODE(ConcurrentFunctionType)
NODE(DifferentiableFunctionType)
NODE(ExistentialMetatype)
CONTEXT_NODE(ExplicitClosure)
CONTEXT_NODE(Extension)
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/Demangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ class Demangler : public NodeFactory {
NodePointer demangleAutoDiffSelfReorderingReabstractionThunk();
NodePointer demangleDifferentiabilityWitness();
NodePointer demangleIndexSubset();
NodePointer demangleDifferentiableFunctionType();

bool demangleBoundGenerics(Vector<NodePointer> &TypeListList,
NodePointer &RetroactiveConformances);
Expand Down
55 changes: 35 additions & 20 deletions include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ class ImplFunctionTypeFlags {

bool isPseudogeneric() const { return Pseudogeneric; }

bool isDifferentiable() const {
return getDifferentiabilityKind() !=
ImplFunctionDifferentiabilityKind::NonDifferentiable;
}

ImplFunctionDifferentiabilityKind getDifferentiabilityKind() const {
return ImplFunctionDifferentiabilityKind(DifferentiabilityKind);
}
Expand Down Expand Up @@ -708,10 +713,6 @@ class TypeDecoder {
case NodeKind::NoEscapeFunctionType:
case NodeKind::AutoClosureType:
case NodeKind::EscapingAutoClosureType:
case NodeKind::DifferentiableFunctionType:
case NodeKind::EscapingDifferentiableFunctionType:
case NodeKind::LinearFunctionType:
case NodeKind::EscapingLinearFunctionType:
case NodeKind::FunctionType: {
if (Node->getNumChildren() < 2)
return MAKE_NODE_TYPE_ERROR(Node,
Expand All @@ -727,15 +728,6 @@ class TypeDecoder {
flags.withConvention(FunctionMetadataConvention::CFunctionPointer);
} else if (Node->getKind() == NodeKind::ThinFunctionType) {
flags = flags.withConvention(FunctionMetadataConvention::Thin);
} else if (Node->getKind() == NodeKind::DifferentiableFunctionType ||
Node->getKind() ==
NodeKind::EscapingDifferentiableFunctionType) {
flags = flags.withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind::Reverse);
} else if (Node->getKind() == NodeKind::LinearFunctionType ||
Node->getKind() == NodeKind::EscapingLinearFunctionType) {
flags = flags.withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind::Linear);
}

unsigned firstChildIdx = 0;
Expand Down Expand Up @@ -766,8 +758,34 @@ class TypeDecoder {
++firstChildIdx;
}

FunctionMetadataDifferentiabilityKind diffKind;
if (Node->getChild(firstChildIdx)->getKind() ==
NodeKind::DifferentiableFunctionType) {
auto mangledDiffKind = (MangledDifferentiabilityKind)
Node->getChild(firstChildIdx)->getIndex();
switch (mangledDiffKind) {
case MangledDifferentiabilityKind::NonDifferentiable:
assert(false && "Unexpected case NonDifferentiable");
break;
case MangledDifferentiabilityKind::Forward:
diffKind = FunctionMetadataDifferentiabilityKind::Forward;
break;
case MangledDifferentiabilityKind::Reverse:
diffKind = FunctionMetadataDifferentiabilityKind::Reverse;
break;
case MangledDifferentiabilityKind::Normal:
diffKind = FunctionMetadataDifferentiabilityKind::Normal;
break;
case MangledDifferentiabilityKind::Linear:
diffKind = FunctionMetadataDifferentiabilityKind::Linear;
break;
}
++firstChildIdx;
}

flags = flags.withConcurrent(isSendable)
.withAsync(isAsync).withThrows(isThrow);
.withAsync(isAsync).withThrows(isThrow)
.withDifferentiable(diffKind.isDifferentiable());

if (Node->getNumChildren() < firstChildIdx + 2)
return MAKE_NODE_TYPE_ERROR(Node,
Expand All @@ -786,16 +804,13 @@ class TypeDecoder {
.withEscaping(
Node->getKind() == NodeKind::FunctionType ||
Node->getKind() == NodeKind::EscapingAutoClosureType ||
Node->getKind() == NodeKind::EscapingObjCBlock ||
Node->getKind() ==
NodeKind::EscapingDifferentiableFunctionType ||
Node->getKind() ==
NodeKind::EscapingLinearFunctionType);
Node->getKind() == NodeKind::EscapingObjCBlock);

auto result = decodeMangledType(Node->getChild(firstChildIdx+1));
if (result.isError())
return result;
return Builder.createFunctionType(parameters, result.getType(), flags);
return Builder.createFunctionType(
parameters, result.getType(), flags, diffKind);
}
case NodeKind::ImplFunctionType: {
auto calleeConvention = ImplParameterConvention::Direct_Unowned;
Expand Down
22 changes: 15 additions & 7 deletions include/swift/Reflection/TypeRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,11 @@ class FunctionTypeRef final : public TypeRef {
std::vector<Param> Parameters;
const TypeRef *Result;
FunctionTypeFlags Flags;
FunctionMetadataDifferentiabilityKind DifferentiabilityKind;

static TypeRefID Profile(const std::vector<Param> &Parameters,
const TypeRef *Result, FunctionTypeFlags Flags) {
const TypeRef *Result, FunctionTypeFlags Flags,
FunctionMetadataDifferentiabilityKind DiffKind) {
TypeRefID ID;
for (const auto &Param : Parameters) {
ID.addString(Param.getLabel().str());
Expand All @@ -472,20 +474,22 @@ class FunctionTypeRef final : public TypeRef {
}
ID.addPointer(Result);
ID.addInteger(static_cast<uint64_t>(Flags.getIntValue()));
ID.addInteger(static_cast<uint64_t>(DiffKind.getIntValue()));
return ID;
}

public:
FunctionTypeRef(std::vector<Param> Params, const TypeRef *Result,
FunctionTypeFlags Flags)
FunctionTypeFlags Flags,
FunctionMetadataDifferentiabilityKind DiffKind)
: TypeRef(TypeRefKind::Function), Parameters(Params), Result(Result),
Flags(Flags) {}
Flags(Flags), DifferentiabilityKind(DiffKind) {}

template <typename Allocator>
static const FunctionTypeRef *create(Allocator &A, std::vector<Param> Params,
const TypeRef *Result,
FunctionTypeFlags Flags) {
FIND_OR_CREATE_TYPEREF(A, FunctionTypeRef, Params, Result, Flags);
static const FunctionTypeRef *create(
Allocator &A, std::vector<Param> Params, const TypeRef *Result,
FunctionTypeFlags Flags, FunctionMetadataDifferentiabilityKind DiffKind) {
FIND_OR_CREATE_TYPEREF(A, FunctionTypeRef, Params, Result, Flags, DiffKind);
}

const std::vector<Param> &getParameters() const { return Parameters; };
Expand All @@ -498,6 +502,10 @@ class FunctionTypeRef final : public TypeRef {
return Flags;
}

FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
return DifferentiabilityKind;
}

static bool classof(const TypeRef *TR) {
return TR->getKind() == TypeRefKind::Function;
}
Expand Down
27 changes: 24 additions & 3 deletions include/swift/Reflection/TypeRefBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,9 @@ class TypeRefBuilder {

const FunctionTypeRef *createFunctionType(
llvm::ArrayRef<remote::FunctionParam<const TypeRef *>> params,
const TypeRef *result, FunctionTypeFlags flags) {
return FunctionTypeRef::create(*this, params, result, flags);
const TypeRef *result, FunctionTypeFlags flags,
FunctionMetadataDifferentiabilityKind diffKind) {
return FunctionTypeRef::create(*this, params, result, flags, diffKind);
}
using BuiltSubstitution = std::pair<const TypeRef *, const TypeRef *>;
using BuiltRequirement = TypeRefRequirement;
Expand Down Expand Up @@ -454,9 +455,29 @@ class TypeRefBuilder {

funcFlags = funcFlags.withConcurrent(flags.isSendable());
funcFlags = funcFlags.withAsync(flags.isAsync());
funcFlags = funcFlags.withDifferentiable(flags.isDifferentiable());

FunctionMetadataDifferentiabilityKind diffKind;
switch (flags.getDifferentiabilityKind()) {
case ImplFunctionDifferentiabilityKind::NonDifferentiable:
diffKind = FunctionMetadataDifferentiabilityKind::NonDifferentiable;
break;
case ImplFunctionDifferentiabilityKind::Forward:
diffKind = FunctionMetadataDifferentiabilityKind::Forward;
break;
case ImplFunctionDifferentiabilityKind::Reverse:
diffKind = FunctionMetadataDifferentiabilityKind::Reverse;
break;
case ImplFunctionDifferentiabilityKind::Normal:
diffKind = FunctionMetadataDifferentiabilityKind::Normal;
break;
case ImplFunctionDifferentiabilityKind::Linear:
diffKind = FunctionMetadataDifferentiabilityKind::Linear;
break;
}

auto result = createTupleType({}, "");
return FunctionTypeRef::create(*this, {}, result, funcFlags);
return FunctionTypeRef::create(*this, {}, result, funcFlags, diffKind);
}

const ProtocolCompositionTypeRef *
Expand Down
Loading