Skip to content

Commit 871f916

Browse files
committed
SILDifferentiableFunctionType
1 parent cdf3234 commit 871f916

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1227
-114
lines changed

include/swift/AST/ASTContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,9 @@ class ASTContext final {
494494
/// has been imported. Otherwise, this returns null.
495495
StructDecl *getTensorDataTypeDecl() const;
496496

497+
/// Retrieve the type for Swift.AnyDerivative.
498+
CanType getAnyDerivativeType() const;
499+
497500
/// Retrieve the type Swift.Never.
498501
CanType getNeverType() const;
499502

include/swift/AST/Attr.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ TYPE_ATTR(noescape)
5353
TYPE_ATTR(escaping)
5454
// SWIFT_ENABLE_TENSORFLOW
5555
TYPE_ATTR(differentiable)
56+
TYPE_ATTR(sil_differentiable)
5657
TYPE_ATTR(autodiff)
5758
TYPE_ATTR(nondiff)
5859

include/swift/AST/Attr.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ class TypeAttributes {
6868
Optional<StringRef> convention = None;
6969
Optional<StringRef> conventionWitnessMethodProtocol = None;
7070

71+
// SWIFT_ENABLE_TENSORFLOW
72+
Optional<std::pair<DifferentiabilityRepresentationKind, unsigned>>
73+
differentiabilityReprKindAndOrder = None;
74+
7175
// For an opened existential type, the known ID.
7276
Optional<UUID> OpenedID;
7377

@@ -126,6 +130,15 @@ class TypeAttributes {
126130
bool hasConvention() const { return convention.hasValue(); }
127131
StringRef getConvention() const { return *convention; }
128132

133+
// SWIFT_ENABLE_TENSORFLOW
134+
bool hasDifferentiabilityRepresentationKindAndOrder() const {
135+
return differentiabilityReprKindAndOrder.hasValue();
136+
}
137+
std::pair<DifferentiabilityRepresentationKind, unsigned>
138+
getDifferentiabilityRepresentationKindAndOrder() const {
139+
return *differentiabilityReprKindAndOrder;
140+
}
141+
129142
bool hasOwnership() const {
130143
return getOwnership() != ReferenceOwnership::Strong;
131144
}

include/swift/AST/AutoDiff.h

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
301301
SmallBitVector indicesBitVec(capacity, false);
302302
for (auto index : indices)
303303
indicesBitVec.set(index);
304-
return AutoDiffIndexSubset::get(ctx, indicesBitVec);
304+
return get(ctx, indicesBitVec);
305305
}
306306

307307
static AutoDiffIndexSubset *getDefault(ASTContext &ctx, unsigned capacity,
@@ -557,6 +557,31 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
557557
}
558558
};
559559

560+
/// The kind of ABI used to represent a differentiable function.
561+
enum class DifferentiabilityRepresentationKind : unsigned {
562+
/// The function is linear and is represented as a bundle of the original
563+
/// function and its transpose. Its differential is the function itself. Its
564+
/// pullback is its transpose.
565+
///
566+
/// For original function `(T...) -> U`, there are a few typing invariants:
567+
/// 1. T = T.TangentVector = T.CotangentVector
568+
/// 2. U = U.TangentVector = U.CotangentVector
569+
///
570+
/// |----------------------|
571+
/// | Original | Transpose |
572+
/// |----------------------|
573+
Linear = 0,
574+
575+
/// The function is represented as a bundle of the original function and
576+
/// JVP functions at every order. JVP functions must be thin.
577+
///
578+
/// 1 2 ... n
579+
/// |----------------------------------------|
580+
/// | Original | JVP@1 | JVP@2 | ... | JVP@n |
581+
/// |----------------------------------------|
582+
Normal = 1
583+
};
584+
560585
/// Automatic differentiation utility namespace.
561586
namespace autodiff {
562587

@@ -606,8 +631,8 @@ class VectorSpace {
606631
Vector,
607632
/// A product of vector spaces as a tuple.
608633
Tuple,
609-
/// A function type whose innermost result conforms to `AdditiveArithmetic`.
610-
Function
634+
/// An existential `AdditiveArithmetic` type.
635+
Existential
611636
};
612637

613638
private:
@@ -617,16 +642,12 @@ class VectorSpace {
617642
Type vectorType;
618643
// Tuple
619644
TupleType *tupleType;
620-
// Function
621-
AnyFunctionType *functionType;
622645

623646
Value(Type vectorType) : vectorType(vectorType) {}
624647
Value(TupleType *tupleType) : tupleType(tupleType) {}
625-
Value(AnyFunctionType *functionType) : functionType(functionType) {}
626648
} value;
627649

628-
VectorSpace(Kind kind, Value value)
629-
: kind(kind), value(value) {}
650+
VectorSpace(Kind kind, Value value) : kind(kind), value(value) {}
630651

631652
public:
632653
VectorSpace() = delete;
@@ -637,12 +658,11 @@ class VectorSpace {
637658
static VectorSpace getTuple(TupleType *tupleTy) {
638659
return {Kind::Tuple, tupleTy};
639660
}
640-
static VectorSpace getFunction(AnyFunctionType *fnTy) {
641-
return {Kind::Function, fnTy};
642-
}
661+
static VectorSpace getExistential(ASTContext &ctx);
643662

644663
bool isVector() const { return kind == Kind::Vector; }
645664
bool isTuple() const { return kind == Kind::Tuple; }
665+
bool isExistential() const { return kind == Kind::Existential; }
646666

647667
Kind getKind() const { return kind; }
648668
Type getVector() const {
@@ -653,10 +673,6 @@ class VectorSpace {
653673
assert(kind == Kind::Tuple);
654674
return value.tupleType;
655675
}
656-
AnyFunctionType *getFunction() const {
657-
assert(kind == Kind::Function);
658-
return value.functionType;
659-
}
660676

661677
Type getType() const;
662678
CanType getCanonicalType() const;

include/swift/AST/DiagnosticsParse.def

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,6 +1393,24 @@ ERROR(convention_attribute_witness_method_expected_colon,none,
13931393
ERROR(convention_attribute_witness_method_expected_protocol,none,
13941394
"expected protocol name in 'witness_method' 'convention' attribute", ())
13951395

1396+
// sil_differentiable
1397+
ERROR(sil_differentiable_attribute_expected_lparen,none,
1398+
"expected '(' after 'sil_differentiable' attribute", ())
1399+
ERROR(sil_differentiable_attribute_expected_max_order,none,
1400+
"expected a max differentiation order in 'sil_differentiable' attribute", ())
1401+
ERROR(sil_differentiable_attribute_expected_rparen,none,
1402+
"expected ')' after convention name for 'sil_differentiable' attribute", ())
1403+
ERROR(sil_differentiable_attribute_expected_lbrace,none,
1404+
"expected '{' in a '@sil_differentiable' type", ())
1405+
ERROR(sil_differentiable_attribute_expected_differential,none,
1406+
"expected 'differential:'", ())
1407+
ERROR(sil_differentiable_attribute_expected_pullback,none,
1408+
"expected 'pullback:' ", ())
1409+
ERROR(sil_differentiable_attribute_expected_transpose,none,
1410+
"expected 'transpose:' ", ())
1411+
ERROR(sil_differentiable_attribute_expected_rbrace,none,
1412+
"expected '}' to end '@sil_differentiable' type", ())
1413+
13961414
// objc
13971415
ERROR(attr_objc_missing_colon,none,
13981416
"missing ':' after selector piece in @objc attribute", ())

include/swift/AST/DiagnosticsSema.def

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3793,6 +3793,21 @@ ERROR(sil_metatype_multiple_reprs,none,
37933793
"metatypes in SIL can only be one of @thin, @thick, or @objc_metatype",
37943794
())
37953795

3796+
// SWIFT_ENABLE_TENSORFLOW
3797+
// @sil_differentiable types
3798+
ERROR(sil_differentiable_attr_not_applicable,none,
3799+
"'@sil_differentiable' is not applicable to this type", ())
3800+
ERROR(sil_differentiable_required_original_function_field,none,
3801+
"an original function type field is required in a '@sil_differentiable'", ())
3802+
ERROR(sil_differentiable_required_field,none,
3803+
"a '%0' function type field is required in a '@sil_differentiable'", (StringRef))
3804+
ERROR(sil_differentiable_fields_must_be_function_type,none,
3805+
"fields in a '@sil_differentiable' type must be function types", ())
3806+
ERROR(sil_differentiable_invalid_field,none,
3807+
"invalid field for the specified '@sil_differentiable' representation kind", ())
3808+
ERROR(sil_differentiable_field_cannot_be_generic,none,
3809+
"'@sil_differentiable' field type cannot be generic", ())
3810+
37963811
//------------------------------------------------------------------------------
37973812
// MARK: @objc and @nonobjc
37983813
//------------------------------------------------------------------------------

include/swift/AST/KnownStdlibTypes.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,7 @@ KNOWN_STDLIB_TYPE_DECL(KeyedEncodingContainer, NominalTypeDecl, 1)
8484
KNOWN_STDLIB_TYPE_DECL(KeyedDecodingContainer, NominalTypeDecl, 1)
8585
KNOWN_STDLIB_TYPE_DECL(RangeReplaceableCollection, ProtocolDecl, 1)
8686

87+
// SWIFT_ENABLE_TENSORFLOW
88+
KNOWN_STDLIB_TYPE_DECL(AnyDerivative, StructDecl, 0)
89+
8790
#undef KNOWN_STDLIB_TYPE_DECL

include/swift/AST/TypeMatcher.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ class TypeMatcher {
239239
TRIVIAL_CASE(SILFunctionType)
240240
TRIVIAL_CASE(SILBlockStorageType)
241241
TRIVIAL_CASE(SILBoxType)
242+
// SWIFT_ENABLE_TENSORFLOW
243+
TRIVIAL_CASE(SILDifferentiableFunctionType)
242244
TRIVIAL_CASE(ProtocolCompositionType)
243245

244246
bool visitLValueType(CanLValueType firstLValue, Type secondType,

include/swift/AST/TypeNodes.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ ARTIFICIAL_TYPE(SILFunction, Type)
148148
ARTIFICIAL_TYPE(SILBlockStorage, Type)
149149
ARTIFICIAL_TYPE(SILBox, Type)
150150
ARTIFICIAL_TYPE(SILToken, Type)
151+
// SWIFT_ENABLE_TENSORFLOW
152+
ARTIFICIAL_TYPE(SILDifferentiableFunction, Type)
151153
TYPE(ProtocolComposition, Type)
152154
TYPE(LValue, Type)
153155
TYPE(InOut, Type)

include/swift/AST/TypeRepr.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,8 @@ inline bool TypeRepr::isSimple() const {
11501150
case TypeReprKind::InOut:
11511151
case TypeReprKind::Composition:
11521152
case TypeReprKind::OpaqueReturn:
1153+
// SWIFT_ENABLE_TENSORFLOW
1154+
case TypeReprKind::SILDifferentiableFunction:
11531155
return false;
11541156
case TypeReprKind::SimpleIdent:
11551157
case TypeReprKind::GenericIdent:
@@ -1170,6 +1172,54 @@ inline bool TypeRepr::isSimple() const {
11701172
llvm_unreachable("bad TypeRepr kind");
11711173
}
11721174

1175+
// SWIFT_ENABLE_TENSORFLOW
1176+
class SILDifferentiableFunctionTypeRepr final : public TypeRepr {
1177+
GenericParamList *GenericParams;
1178+
GenericEnvironment *GenericEnv = nullptr;
1179+
TypeRepr *Original;
1180+
TypeRepr *Differential;
1181+
TypeRepr *Pullback;
1182+
TypeRepr *Transpose;
1183+
SourceRange Braces;
1184+
1185+
public:
1186+
SILDifferentiableFunctionTypeRepr(
1187+
GenericParamList *genericParams, TypeRepr *original,
1188+
TypeRepr *differential, TypeRepr *pullback, TypeRepr *transpose,
1189+
SourceRange braces)
1190+
: TypeRepr(TypeReprKind::SILDifferentiableFunction),
1191+
GenericParams(genericParams), Original(original),
1192+
Differential(differential), Pullback(pullback), Transpose(transpose),
1193+
Braces(braces) {}
1194+
1195+
GenericParamList *getGenericParams() const { return GenericParams; };
1196+
GenericEnvironment *getGenericEnvironment() const { return GenericEnv; };
1197+
void setGenericEnvironment(GenericEnvironment *env) {
1198+
assert(GenericEnv == nullptr);
1199+
GenericEnv = env;
1200+
}
1201+
TypeRepr *getOriginal() const { return Original; }
1202+
TypeRepr *getDifferential() const { return Differential; }
1203+
TypeRepr *getPullback() const { return Pullback; }
1204+
TypeRepr *getTranspose() const { return Transpose; }
1205+
1206+
SourceRange getBraces() const { return Braces; }
1207+
1208+
static bool classof(const TypeRepr *T) {
1209+
return T->getKind() == TypeReprKind::SILDifferentiableFunction;
1210+
}
1211+
1212+
static bool classof(const SILDifferentiableFunctionTypeRepr *T) {
1213+
return true;
1214+
}
1215+
1216+
private:
1217+
SourceLoc getStartLocImpl() const { return Braces.Start; }
1218+
SourceLoc getEndLocImpl() const { return Braces.End; }
1219+
void printImpl(ASTPrinter &Printer, const PrintOptions &Opts) const;
1220+
friend class TypeRepr;
1221+
};
1222+
11731223
} // end namespace swift
11741224

11751225
namespace llvm {

include/swift/AST/TypeReprNodes.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ ABSTRACT_TYPEREPR(Specifier, TypeRepr)
6060
TYPEREPR(Owned, SpecifierTypeRepr)
6161
TYPEREPR(Fixed, TypeRepr)
6262
TYPEREPR(SILBox, TypeRepr)
63+
// SWIFT_ENABLE_TENSORFLOW
64+
TYPEREPR(SILDifferentiableFunction, TypeRepr)
6365
LAST_TYPEREPR(SILBox)
6466

6567
#undef ABSTRACT_TYPEREPR

0 commit comments

Comments
 (0)