Skip to content

Commit f5282ed

Browse files
committed
Fix function differentiability kind metadata and mangling.
* Move differentiability kinds from target function type metadata to trailing objects so that we don't exhaust all remaining bits of function type metadata. * Fix mangling of different differentiability kinds in function types. Mangle it like `ConcurrentFunctionType` so that we can drop special cases for escaping functions. Resolves rdar://75240064.
1 parent 8f52c26 commit f5282ed

File tree

28 files changed

+596
-232
lines changed

28 files changed

+596
-232
lines changed

docs/ABI/Mangling.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -559,24 +559,24 @@ Types
559559
FUNCTION-KIND ::= 'zC' C-TYPE // C function pointer type with with non-canonical C type
560560
FUNCTION-KIND ::= 'A' // @auto_closure function type (escaping)
561561
FUNCTION-KIND ::= 'E' // function type (noescape)
562-
FUNCTION-KIND ::= 'F' // @differentiable function type
563-
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
564-
FUNCTION-KIND ::= 'H' // @differentiable(_linear) function type
565-
FUNCTION-KIND ::= 'I' // @differentiable(_linear) function type (escaping)
566562

567563
C-TYPE is mangled according to the Itanium ABI, and prefixed with the length.
568564
Non-ASCII identifiers are preserved as-is; we do not use Punycode.
569565

570-
function-signature ::= params-type params-type async? sendable? throws? // results and parameters
566+
function-signature ::= params-type params-type async? sendable? throws? differentiable? // results and parameters
571567

572-
params-type ::= type 'z'? 'h'? // tuple in case of multiple parameters or a single parameter with a single tuple type
568+
params-type ::= type 'z'? 'h'? // tuple in case of multiple parameters or a single parameter with a single tuple type
573569
// with optional inout convention, shared convention. parameters don't have labels,
574570
// they are mangled separately as part of the entity.
575-
params-type ::= empty-list // shortcut for no parameters
571+
params-type ::= empty-list // shortcut for no parameters
576572

577-
sendable ::= 'J' // @Sendable on function types
573+
sendable ::= 'J' // @Sendable on function types
578574
async ::= 'Y' // 'async' annotation on function types
579575
throws ::= 'K' // 'throws' annotation on function types
576+
differentiable ::= 'jf' // @differentiable(_forward) on function type
577+
differentiable ::= 'jr' // @differentiable(reverse) on function type
578+
differentiable ::= 'jd' // @differentiable on function type
579+
differentiable ::= 'jl' // @differentiable(_linear) on function type
580580

581581
type-list ::= list-type '_' list-type* // list of types
582582
type-list ::= empty-list

include/swift/ABI/Metadata.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,6 +1658,7 @@ struct TargetFunctionTypeMetadata : public TargetMetadata<Runtime> {
16581658
bool isAsync() const { return Flags.isAsync(); }
16591659
bool isThrowing() const { return Flags.isThrowing(); }
16601660
bool isSendable() const { return Flags.isSendable(); }
1661+
bool isDifferentiable() const { return Flags.isDifferentiable(); }
16611662
bool hasParameterFlags() const { return Flags.hasParameterFlags(); }
16621663
bool isEscaping() const { return Flags.isEscaping(); }
16631664

@@ -1675,6 +1676,28 @@ struct TargetFunctionTypeMetadata : public TargetMetadata<Runtime> {
16751676
return reinterpret_cast<const uint32_t *>(getParameters() +
16761677
getNumParameters());
16771678
}
1679+
1680+
TargetFunctionMetadataDifferentiabilityKind<StoredSize> *
1681+
getDifferentiabilityKindAddress() {
1682+
assert(isDifferentiable());
1683+
void *previousEndAddr = hasParameterFlags()
1684+
? reinterpret_cast<void *>(getParameterFlags() + getNumParameters())
1685+
: reinterpret_cast<void *>(getParameters() + getNumParameters());
1686+
return reinterpret_cast<
1687+
TargetFunctionMetadataDifferentiabilityKind<StoredSize> *>(
1688+
llvm::alignAddr(previousEndAddr,
1689+
llvm::Align(alignof(typename Runtime::StoredPointer))));
1690+
}
1691+
1692+
TargetFunctionMetadataDifferentiabilityKind<StoredSize>
1693+
getDifferentiabilityKind() const {
1694+
if (isDifferentiable()) {
1695+
return *const_cast<TargetFunctionTypeMetadata<Runtime> *>(this)
1696+
->getDifferentiabilityKindAddress();
1697+
}
1698+
return TargetFunctionMetadataDifferentiabilityKind<StoredSize>
1699+
::NonDifferentiable;
1700+
}
16781701
};
16791702
using FunctionTypeMetadata = TargetFunctionTypeMetadata<InProcess>;
16801703

include/swift/ABI/MetadataValues.h

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -787,13 +787,29 @@ enum class FunctionMetadataConvention: uint8_t {
787787

788788
/// Differentiability kind for function type metadata.
789789
/// Duplicates `DifferentiabilityKind` in AST/AutoDiff.h.
790-
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
791-
NonDifferentiable = 0b00000,
792-
Forward = 0b00001,
793-
Reverse = 0b00010,
794-
Normal = 0b00011,
795-
Linear = 0b10000,
790+
template <typename int_type>
791+
struct TargetFunctionMetadataDifferentiabilityKind {
792+
enum Value : int_type {
793+
NonDifferentiable = 0,
794+
Forward = 1,
795+
Reverse = 2,
796+
Normal = 3,
797+
Linear = 4,
798+
} Value;
799+
800+
constexpr TargetFunctionMetadataDifferentiabilityKind(
801+
enum Value value = NonDifferentiable) : Value(value) {}
802+
803+
int_type getIntValue() const {
804+
return (int_type)Value;
805+
}
806+
807+
bool isDifferentiable() const {
808+
return Value != NonDifferentiable;
809+
}
796810
};
811+
using FunctionMetadataDifferentiabilityKind =
812+
TargetFunctionMetadataDifferentiabilityKind<size_t>;
797813

798814
/// Flags in a function type metadata record.
799815
template <typename int_type>
@@ -808,8 +824,7 @@ class TargetFunctionTypeFlags {
808824
ThrowsMask = 0x01000000U,
809825
ParamFlagsMask = 0x02000000U,
810826
EscapingMask = 0x04000000U,
811-
DifferentiabilityMask = 0x98000000U,
812-
DifferentiabilityShift = 27U,
827+
DifferentiableMask = 0x08000000U,
813828
AsyncMask = 0x20000000U,
814829
SendableMask = 0x40000000U,
815830
};
@@ -842,10 +857,10 @@ class TargetFunctionTypeFlags {
842857
(throws ? ThrowsMask : 0));
843858
}
844859

845-
constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
846-
FunctionMetadataDifferentiabilityKind differentiabilityKind) const {
847-
return TargetFunctionTypeFlags((Data & ~DifferentiabilityMask)
848-
| (int_type(differentiabilityKind) << DifferentiabilityShift));
860+
constexpr TargetFunctionTypeFlags<int_type>
861+
withDifferentiable(bool differentiable) const {
862+
return TargetFunctionTypeFlags<int_type>((Data & ~DifferentiableMask) |
863+
(differentiable ? DifferentiableMask : 0));
849864
}
850865

851866
constexpr TargetFunctionTypeFlags<int_type>
@@ -888,13 +903,7 @@ class TargetFunctionTypeFlags {
888903
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }
889904

890905
bool isDifferentiable() const {
891-
return getDifferentiabilityKind() !=
892-
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
893-
}
894-
895-
FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
896-
return FunctionMetadataDifferentiabilityKind(
897-
(Data & DifferentiabilityMask) >> DifferentiabilityShift);
906+
return bool (Data & DifferentiableMask);
898907
}
899908

900909
int_type getIntValue() const {
@@ -947,6 +956,12 @@ class TargetParameterTypeFlags {
947956
(Data & ~AutoClosureMask) | (isAutoClosure ? AutoClosureMask : 0));
948957
}
949958

959+
constexpr TargetParameterTypeFlags<int_type>
960+
withNoDerivative(bool isNoDerivative) const {
961+
return TargetParameterTypeFlags<int_type>(
962+
(Data & ~NoDerivativeMask) | (isNoDerivative ? NoDerivativeMask : 0));
963+
}
964+
950965
bool isNone() const { return Data == 0; }
951966
bool isVariadic() const { return Data & VariadicMask; }
952967
bool isAutoClosure() const { return Data & AutoClosureMask; }

include/swift/AST/ASTDemangler.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,10 @@ class ASTBuilder {
9797

9898
Type createTupleType(ArrayRef<Type> eltTypes, StringRef labels);
9999

100-
Type createFunctionType(ArrayRef<Demangle::FunctionParam<Type>> params,
101-
Type output, FunctionTypeFlags flags);
100+
Type createFunctionType(
101+
ArrayRef<Demangle::FunctionParam<Type>> params,
102+
Type output, FunctionTypeFlags flags,
103+
FunctionMetadataDifferentiabilityKind diffKind);
102104

103105
Type createImplFunctionType(
104106
Demangle::ImplParameterConvention calleeConvention,

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,6 @@ NODE(DependentProtocolConformanceInherited)
6969
NODE(DependentProtocolConformanceAssociated)
7070
CONTEXT_NODE(Destructor)
7171
CONTEXT_NODE(DidSet)
72-
NODE(DifferentiableFunctionType)
73-
NODE(EscapingDifferentiableFunctionType)
74-
NODE(LinearFunctionType)
75-
NODE(EscapingLinearFunctionType)
7672
NODE(Directness)
7773
NODE(DynamicAttribute)
7874
NODE(DirectMethodReferenceAttribute)
@@ -86,6 +82,7 @@ NODE(ErrorType)
8682
NODE(EscapingAutoClosureType)
8783
NODE(NoEscapeFunctionType)
8884
NODE(ConcurrentFunctionType)
85+
NODE(DifferentiableFunctionType)
8986
NODE(ExistentialMetatype)
9087
CONTEXT_NODE(ExplicitClosure)
9188
CONTEXT_NODE(Extension)

include/swift/Demangling/Demangler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ class Demangler : public NodeFactory {
575575
NodePointer demangleAutoDiffSelfReorderingReabstractionThunk();
576576
NodePointer demangleDifferentiabilityWitness();
577577
NodePointer demangleIndexSubset();
578+
NodePointer demangleDifferentiableFunctionType();
578579

579580
bool demangleBoundGenerics(Vector<NodePointer> &TypeListList,
580581
NodePointer &RetroactiveConformances);

include/swift/Demangling/TypeDecoder.h

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ class ImplFunctionTypeFlags {
316316

317317
bool isPseudogeneric() const { return Pseudogeneric; }
318318

319+
bool isDifferentiable() const {
320+
return getDifferentiabilityKind() !=
321+
ImplFunctionDifferentiabilityKind::NonDifferentiable;
322+
}
323+
319324
ImplFunctionDifferentiabilityKind getDifferentiabilityKind() const {
320325
return ImplFunctionDifferentiabilityKind(DifferentiabilityKind);
321326
}
@@ -708,10 +713,6 @@ class TypeDecoder {
708713
case NodeKind::NoEscapeFunctionType:
709714
case NodeKind::AutoClosureType:
710715
case NodeKind::EscapingAutoClosureType:
711-
case NodeKind::DifferentiableFunctionType:
712-
case NodeKind::EscapingDifferentiableFunctionType:
713-
case NodeKind::LinearFunctionType:
714-
case NodeKind::EscapingLinearFunctionType:
715716
case NodeKind::FunctionType: {
716717
if (Node->getNumChildren() < 2)
717718
return MAKE_NODE_TYPE_ERROR(Node,
@@ -727,15 +728,6 @@ class TypeDecoder {
727728
flags.withConvention(FunctionMetadataConvention::CFunctionPointer);
728729
} else if (Node->getKind() == NodeKind::ThinFunctionType) {
729730
flags = flags.withConvention(FunctionMetadataConvention::Thin);
730-
} else if (Node->getKind() == NodeKind::DifferentiableFunctionType ||
731-
Node->getKind() ==
732-
NodeKind::EscapingDifferentiableFunctionType) {
733-
flags = flags.withDifferentiabilityKind(
734-
FunctionMetadataDifferentiabilityKind::Reverse);
735-
} else if (Node->getKind() == NodeKind::LinearFunctionType ||
736-
Node->getKind() == NodeKind::EscapingLinearFunctionType) {
737-
flags = flags.withDifferentiabilityKind(
738-
FunctionMetadataDifferentiabilityKind::Linear);
739731
}
740732

741733
unsigned firstChildIdx = 0;
@@ -766,8 +758,34 @@ class TypeDecoder {
766758
++firstChildIdx;
767759
}
768760

761+
FunctionMetadataDifferentiabilityKind diffKind;
762+
if (Node->getChild(firstChildIdx)->getKind() ==
763+
NodeKind::DifferentiableFunctionType) {
764+
auto mangledDiffKind = (MangledDifferentiabilityKind)
765+
Node->getChild(firstChildIdx)->getIndex();
766+
switch (mangledDiffKind) {
767+
case MangledDifferentiabilityKind::NonDifferentiable:
768+
assert(false && "Unexpected case NonDifferentiable");
769+
break;
770+
case MangledDifferentiabilityKind::Forward:
771+
diffKind = FunctionMetadataDifferentiabilityKind::Forward;
772+
break;
773+
case MangledDifferentiabilityKind::Reverse:
774+
diffKind = FunctionMetadataDifferentiabilityKind::Reverse;
775+
break;
776+
case MangledDifferentiabilityKind::Normal:
777+
diffKind = FunctionMetadataDifferentiabilityKind::Normal;
778+
break;
779+
case MangledDifferentiabilityKind::Linear:
780+
diffKind = FunctionMetadataDifferentiabilityKind::Linear;
781+
break;
782+
}
783+
++firstChildIdx;
784+
}
785+
769786
flags = flags.withConcurrent(isSendable)
770-
.withAsync(isAsync).withThrows(isThrow);
787+
.withAsync(isAsync).withThrows(isThrow)
788+
.withDifferentiable(diffKind.isDifferentiable());
771789

772790
if (Node->getNumChildren() < firstChildIdx + 2)
773791
return MAKE_NODE_TYPE_ERROR(Node,
@@ -786,16 +804,13 @@ class TypeDecoder {
786804
.withEscaping(
787805
Node->getKind() == NodeKind::FunctionType ||
788806
Node->getKind() == NodeKind::EscapingAutoClosureType ||
789-
Node->getKind() == NodeKind::EscapingObjCBlock ||
790-
Node->getKind() ==
791-
NodeKind::EscapingDifferentiableFunctionType ||
792-
Node->getKind() ==
793-
NodeKind::EscapingLinearFunctionType);
807+
Node->getKind() == NodeKind::EscapingObjCBlock);
794808

795809
auto result = decodeMangledType(Node->getChild(firstChildIdx+1));
796810
if (result.isError())
797811
return result;
798-
return Builder.createFunctionType(parameters, result.getType(), flags);
812+
return Builder.createFunctionType(
813+
parameters, result.getType(), flags, diffKind);
799814
}
800815
case NodeKind::ImplFunctionType: {
801816
auto calleeConvention = ImplParameterConvention::Direct_Unowned;

include/swift/Reflection/TypeRef.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -461,9 +461,11 @@ class FunctionTypeRef final : public TypeRef {
461461
std::vector<Param> Parameters;
462462
const TypeRef *Result;
463463
FunctionTypeFlags Flags;
464+
FunctionMetadataDifferentiabilityKind DifferentiabilityKind;
464465

465466
static TypeRefID Profile(const std::vector<Param> &Parameters,
466-
const TypeRef *Result, FunctionTypeFlags Flags) {
467+
const TypeRef *Result, FunctionTypeFlags Flags,
468+
FunctionMetadataDifferentiabilityKind DiffKind) {
467469
TypeRefID ID;
468470
for (const auto &Param : Parameters) {
469471
ID.addString(Param.getLabel().str());
@@ -472,20 +474,22 @@ class FunctionTypeRef final : public TypeRef {
472474
}
473475
ID.addPointer(Result);
474476
ID.addInteger(static_cast<uint64_t>(Flags.getIntValue()));
477+
ID.addInteger(static_cast<uint64_t>(DiffKind.getIntValue()));
475478
return ID;
476479
}
477480

478481
public:
479482
FunctionTypeRef(std::vector<Param> Params, const TypeRef *Result,
480-
FunctionTypeFlags Flags)
483+
FunctionTypeFlags Flags,
484+
FunctionMetadataDifferentiabilityKind DiffKind)
481485
: TypeRef(TypeRefKind::Function), Parameters(Params), Result(Result),
482-
Flags(Flags) {}
486+
Flags(Flags), DifferentiabilityKind(DiffKind) {}
483487

484488
template <typename Allocator>
485-
static const FunctionTypeRef *create(Allocator &A, std::vector<Param> Params,
486-
const TypeRef *Result,
487-
FunctionTypeFlags Flags) {
488-
FIND_OR_CREATE_TYPEREF(A, FunctionTypeRef, Params, Result, Flags);
489+
static const FunctionTypeRef *create(
490+
Allocator &A, std::vector<Param> Params, const TypeRef *Result,
491+
FunctionTypeFlags Flags, FunctionMetadataDifferentiabilityKind DiffKind) {
492+
FIND_OR_CREATE_TYPEREF(A, FunctionTypeRef, Params, Result, Flags, DiffKind);
489493
}
490494

491495
const std::vector<Param> &getParameters() const { return Parameters; };
@@ -498,6 +502,10 @@ class FunctionTypeRef final : public TypeRef {
498502
return Flags;
499503
}
500504

505+
FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
506+
return DifferentiabilityKind;
507+
}
508+
501509
static bool classof(const TypeRef *TR) {
502510
return TR->getKind() == TypeRefKind::Function;
503511
}

include/swift/Reflection/TypeRefBuilder.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,9 @@ class TypeRefBuilder {
414414

415415
const FunctionTypeRef *createFunctionType(
416416
llvm::ArrayRef<remote::FunctionParam<const TypeRef *>> params,
417-
const TypeRef *result, FunctionTypeFlags flags) {
418-
return FunctionTypeRef::create(*this, params, result, flags);
417+
const TypeRef *result, FunctionTypeFlags flags,
418+
FunctionMetadataDifferentiabilityKind diffKind) {
419+
return FunctionTypeRef::create(*this, params, result, flags, diffKind);
419420
}
420421
using BuiltSubstitution = std::pair<const TypeRef *, const TypeRef *>;
421422
using BuiltRequirement = TypeRefRequirement;
@@ -454,9 +455,29 @@ class TypeRefBuilder {
454455

455456
funcFlags = funcFlags.withConcurrent(flags.isSendable());
456457
funcFlags = funcFlags.withAsync(flags.isAsync());
458+
funcFlags = funcFlags.withDifferentiable(flags.isDifferentiable());
459+
460+
FunctionMetadataDifferentiabilityKind diffKind;
461+
switch (flags.getDifferentiabilityKind()) {
462+
case ImplFunctionDifferentiabilityKind::NonDifferentiable:
463+
diffKind = FunctionMetadataDifferentiabilityKind::NonDifferentiable;
464+
break;
465+
case ImplFunctionDifferentiabilityKind::Forward:
466+
diffKind = FunctionMetadataDifferentiabilityKind::Forward;
467+
break;
468+
case ImplFunctionDifferentiabilityKind::Reverse:
469+
diffKind = FunctionMetadataDifferentiabilityKind::Reverse;
470+
break;
471+
case ImplFunctionDifferentiabilityKind::Normal:
472+
diffKind = FunctionMetadataDifferentiabilityKind::Normal;
473+
break;
474+
case ImplFunctionDifferentiabilityKind::Linear:
475+
diffKind = FunctionMetadataDifferentiabilityKind::Linear;
476+
break;
477+
}
457478

458479
auto result = createTupleType({}, "");
459-
return FunctionTypeRef::create(*this, {}, result, funcFlags);
480+
return FunctionTypeRef::create(*this, {}, result, funcFlags, diffKind);
460481
}
461482

462483
const ProtocolCompositionTypeRef *

0 commit comments

Comments
 (0)