Skip to content

[AutoDiff] Add 'SILDifferentiableFunctionType'. #23482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,9 @@ class ASTContext final {
/// has been imported. Otherwise, this returns null.
StructDecl *getTensorDataTypeDecl() const;

/// Retrieve the type for Swift.AnyDerivative.
CanType getAnyDerivativeType() const;

/// Retrieve the type Swift.Never.
CanType getNeverType() const;

Expand Down
46 changes: 31 additions & 15 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
SmallBitVector indicesBitVec(capacity, false);
for (auto index : indices)
indicesBitVec.set(index);
return AutoDiffIndexSubset::get(ctx, indicesBitVec);
return get(ctx, indicesBitVec);
}

static AutoDiffIndexSubset *getDefault(ASTContext &ctx, unsigned capacity,
Expand Down Expand Up @@ -557,6 +557,31 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
}
};

/// The kind of ABI used to represent a differentiable function.
enum class DifferentiabilityRepresentationKind : unsigned {
/// The function is linear and is represented as a bundle of the original
/// function and its transpose. Its differential is the function itself. Its
/// pullback is its transpose.
///
/// For original function `(T...) -> U`, there are a few typing invariants:
/// 1. T = T.TangentVector = T.CotangentVector
/// 2. U = U.TangentVector = U.CotangentVector
///
/// |----------------------|
/// | Original | Transpose |
/// |----------------------|
Linear = 0,

/// The function is represented as a bundle of the original function and
/// JVP functions at every order. JVP functions must be thin.
///
/// 1 2 ... n
/// |----------------------------------------|
/// | Original | JVP@1 | JVP@2 | ... | JVP@n |
/// |----------------------------------------|
Normal = 1
};

/// Automatic differentiation utility namespace.
namespace autodiff {

Expand Down Expand Up @@ -606,8 +631,8 @@ class VectorSpace {
Vector,
/// A product of vector spaces as a tuple.
Tuple,
/// A function type whose innermost result conforms to `AdditiveArithmetic`.
Function
/// An existential `AdditiveArithmetic` type.
Existential
};

private:
Expand All @@ -617,16 +642,12 @@ class VectorSpace {
Type vectorType;
// Tuple
TupleType *tupleType;
// Function
AnyFunctionType *functionType;

Value(Type vectorType) : vectorType(vectorType) {}
Value(TupleType *tupleType) : tupleType(tupleType) {}
Value(AnyFunctionType *functionType) : functionType(functionType) {}
} value;

VectorSpace(Kind kind, Value value)
: kind(kind), value(value) {}
VectorSpace(Kind kind, Value value) : kind(kind), value(value) {}

public:
VectorSpace() = delete;
Expand All @@ -637,12 +658,11 @@ class VectorSpace {
static VectorSpace getTuple(TupleType *tupleTy) {
return {Kind::Tuple, tupleTy};
}
static VectorSpace getFunction(AnyFunctionType *fnTy) {
return {Kind::Function, fnTy};
}
static VectorSpace getExistential(ASTContext &ctx);

bool isVector() const { return kind == Kind::Vector; }
bool isTuple() const { return kind == Kind::Tuple; }
bool isExistential() const { return kind == Kind::Existential; }

Kind getKind() const { return kind; }
Type getVector() const {
Expand All @@ -653,10 +673,6 @@ class VectorSpace {
assert(kind == Kind::Tuple);
return value.tupleType;
}
AnyFunctionType *getFunction() const {
assert(kind == Kind::Function);
return value.functionType;
}

Type getType() const;
CanType getCanonicalType() const;
Expand Down
19 changes: 19 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,25 @@ ERROR(convention_attribute_witness_method_expected_colon,none,
ERROR(convention_attribute_witness_method_expected_protocol,none,
"expected protocol name in 'witness_method' 'convention' attribute", ())

// sil_differentiable
ERROR(sil_differentiable_attribute_expected_lparen,none,
"expected '(' after 'sil_differentiable'", ())
ERROR(sil_differentiable_attribute_expected_max_order,none,
"expected a max differentiation order in 'sil_differentiable(...)'", ())
ERROR(sil_differentiable_attribute_expected_rparen,none,
"expected ')' after the representation kind or order for "
"'sil_differentiable'", ())
ERROR(sil_differentiable_attribute_expected_lbrace,none,
"expected '{' in a 'sil_differentiable' type", ())
ERROR(sil_differentiable_attribute_expected_differential,none,
"expected 'differential:'", ())
ERROR(sil_differentiable_attribute_expected_pullback,none,
"expected 'pullback:' ", ())
ERROR(sil_differentiable_attribute_expected_transpose,none,
"expected 'transpose:' ", ())
ERROR(sil_differentiable_attribute_expected_rbrace,none,
"expected '}' to end a 'sil_differentiable' type", ())

// objc
ERROR(attr_objc_missing_colon,none,
"missing ':' after selector piece in @objc attribute", ())
Expand Down
18 changes: 18 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3793,6 +3793,24 @@ ERROR(sil_metatype_multiple_reprs,none,
"metatypes in SIL can only be one of @thin, @thick, or @objc_metatype",
())

// SWIFT_ENABLE_TENSORFLOW
// @sil_differentiable types
ERROR(sil_differentiable_attr_not_applicable,none,
"'sil_differentiable' is not applicable to this type", ())
ERROR(sil_differentiable_required_original_function_field,none,
"an original function type field is required in a 'sil_differentiable'",
())
ERROR(sil_differentiable_required_field,none,
"a '%0' function type field is required in a 'sil_differentiable'",
(StringRef))
ERROR(sil_differentiable_fields_must_be_function_type,none,
"fields in a 'sil_differentiable' type must be function types", ())
ERROR(sil_differentiable_invalid_field,none,
"invalid field for the specified '@sil_differentiable' representation "
"kind", ())
ERROR(sil_differentiable_field_cannot_be_generic,none,
"'sil_differentiable' field type cannot be generic", ())

//------------------------------------------------------------------------------
// MARK: @objc and @nonobjc
//------------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/KnownStdlibTypes.def
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,7 @@ KNOWN_STDLIB_TYPE_DECL(KeyedEncodingContainer, NominalTypeDecl, 1)
KNOWN_STDLIB_TYPE_DECL(KeyedDecodingContainer, NominalTypeDecl, 1)
KNOWN_STDLIB_TYPE_DECL(RangeReplaceableCollection, ProtocolDecl, 1)

// SWIFT_ENABLE_TENSORFLOW
KNOWN_STDLIB_TYPE_DECL(AnyDerivative, StructDecl, 0)

#undef KNOWN_STDLIB_TYPE_DECL
2 changes: 2 additions & 0 deletions include/swift/AST/TypeMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ class TypeMatcher {
TRIVIAL_CASE(SILFunctionType)
TRIVIAL_CASE(SILBlockStorageType)
TRIVIAL_CASE(SILBoxType)
// SWIFT_ENABLE_TENSORFLOW
TRIVIAL_CASE(SILDifferentiableFunctionType)
TRIVIAL_CASE(ProtocolCompositionType)

bool visitLValueType(CanLValueType firstLValue, Type secondType,
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/TypeNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ ARTIFICIAL_TYPE(SILFunction, Type)
ARTIFICIAL_TYPE(SILBlockStorage, Type)
ARTIFICIAL_TYPE(SILBox, Type)
ARTIFICIAL_TYPE(SILToken, Type)
// SWIFT_ENABLE_TENSORFLOW
ARTIFICIAL_TYPE(SILDifferentiableFunction, Type)
TYPE(ProtocolComposition, Type)
TYPE(LValue, Type)
TYPE(InOut, Type)
Expand Down
57 changes: 57 additions & 0 deletions include/swift/AST/TypeRepr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,8 @@ inline bool TypeRepr::isSimple() const {
case TypeReprKind::InOut:
case TypeReprKind::Composition:
case TypeReprKind::OpaqueReturn:
// SWIFT_ENABLE_TENSORFLOW
case TypeReprKind::SILDifferentiableFunction:
return false;
case TypeReprKind::SimpleIdent:
case TypeReprKind::GenericIdent:
Expand All @@ -1170,6 +1172,61 @@ inline bool TypeRepr::isSimple() const {
llvm_unreachable("bad TypeRepr kind");
}

// SWIFT_ENABLE_TENSORFLOW
class SILDifferentiableFunctionTypeRepr final : public TypeRepr {
GenericParamList *GenericParams;
DifferentiabilityRepresentationKind reprKind;
int maxOrder;
GenericEnvironment *GenericEnv = nullptr;
TypeRepr *Original;
TypeRepr *Differential;
TypeRepr *Pullback;
TypeRepr *Transpose;
SourceRange Braces;

public:
SILDifferentiableFunctionTypeRepr(
GenericParamList *genericParams,
DifferentiabilityRepresentationKind reprKind, int maxOrder,
TypeRepr *original, TypeRepr *differential, TypeRepr *pullback,
TypeRepr *transpose, SourceRange braces)
: TypeRepr(TypeReprKind::SILDifferentiableFunction),
GenericParams(genericParams), reprKind(reprKind), maxOrder(maxOrder),
Original(original), Differential(differential), Pullback(pullback),
Transpose(transpose), Braces(braces) {}

GenericParamList *getGenericParams() const { return GenericParams; };
GenericEnvironment *getGenericEnvironment() const { return GenericEnv; };
void setGenericEnvironment(GenericEnvironment *env) {
assert(GenericEnv == nullptr);
GenericEnv = env;
}
DifferentiabilityRepresentationKind getRepresentationKind() const {
return reprKind;
}
int getMaxOrder() const { return maxOrder; }
TypeRepr *getOriginal() const { return Original; }
TypeRepr *getDifferential() const { return Differential; }
TypeRepr *getPullback() const { return Pullback; }
TypeRepr *getTranspose() const { return Transpose; }

SourceRange getBraces() const { return Braces; }

static bool classof(const TypeRepr *T) {
return T->getKind() == TypeReprKind::SILDifferentiableFunction;
}

static bool classof(const SILDifferentiableFunctionTypeRepr *T) {
return true;
}

private:
SourceLoc getStartLocImpl() const { return Braces.Start; }
SourceLoc getEndLocImpl() const { return Braces.End; }
void printImpl(ASTPrinter &Printer, const PrintOptions &Opts) const;
friend class TypeRepr;
};

} // end namespace swift

namespace llvm {
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/TypeReprNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ ABSTRACT_TYPEREPR(Specifier, TypeRepr)
TYPEREPR(Owned, SpecifierTypeRepr)
TYPEREPR(Fixed, TypeRepr)
TYPEREPR(SILBox, TypeRepr)
// SWIFT_ENABLE_TENSORFLOW
TYPEREPR(SILDifferentiableFunction, TypeRepr)
LAST_TYPEREPR(SILBox)

#undef ABSTRACT_TYPEREPR
Expand Down
Loading