Skip to content

Commit bad0d63

Browse files
committed
Merge branch 'tensorflow' of https://github.com/apple/swift into tensorflow
2 parents 5a8eb79 + 8c4853a commit bad0d63

File tree

61 files changed

+3063
-273
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+3063
-273
lines changed

docs/ABI/Mangling.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,10 @@ Types
516516
FUNCTION-KIND ::= 'C' // C function pointer type
517517
FUNCTION-KIND ::= 'A' // @auto_closure function type (escaping)
518518
FUNCTION-KIND ::= 'E' // function type (noescape)
519+
FUNCTION-KIND ::= 'F' // @differentiable function type
520+
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
521+
FUNCTION-KIND ::= 'H' // @differentiable(linear) function type
522+
FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping)
519523

520524
function-signature ::= params-type params-type throws? // results and parameters
521525

include/swift/ABI/MetadataValues.h

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,14 @@ enum class FunctionMetadataConvention: uint8_t {
733733
CFunctionPointer = 3,
734734
};
735735

736+
/// Differentiability kind for function type metadata.
737+
/// Duplicates `DifferentiabilityKind` in AutoDiff.h.
738+
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
739+
NonDifferentiable = 0b00,
740+
Normal = 0b01,
741+
Linear = 0b11
742+
};
743+
736744
/// Flags in a function type metadata record.
737745
template <typename int_type>
738746
class TargetFunctionTypeFlags {
@@ -747,7 +755,8 @@ class TargetFunctionTypeFlags {
747755
ParamFlagsMask = 0x02000000U,
748756
EscapingMask = 0x04000000U,
749757
// SWIFT_ENABLE_TENSORFLOW
750-
DifferentiableMask = 0x08000000U
758+
DifferentiableMask = 0x08000000U,
759+
LinearMask = 0x10000000U
751760
};
752761
int_type Data;
753762

@@ -785,10 +794,14 @@ class TargetFunctionTypeFlags {
785794
}
786795

787796
// SWIFT_ENABLE_TENSORFLOW
788-
constexpr TargetFunctionTypeFlags<int_type>
789-
withDifferentiable(bool isDifferentiable) const {
790-
return TargetFunctionTypeFlags<int_type>((Data & ~DifferentiableMask) |
791-
(isDifferentiable ? DifferentiableMask : 0));
797+
constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
798+
FunctionMetadataDifferentiabilityKind differentiability) const {
799+
return TargetFunctionTypeFlags<int_type>(
800+
(Data & ~DifferentiableMask & ~LinearMask) |
801+
(differentiability == FunctionMetadataDifferentiabilityKind::Normal
802+
? DifferentiableMask : 0) |
803+
(differentiability == FunctionMetadataDifferentiabilityKind::Linear
804+
? LinearMask : 0));
792805
}
793806

794807
unsigned getNumParameters() const { return Data & NumParametersMask; }
@@ -807,7 +820,15 @@ class TargetFunctionTypeFlags {
807820

808821
// SWIFT_ENABLE_TENSORFLOW
809822
bool isDifferentiable() const {
810-
return bool (Data & DifferentiableMask);
823+
return getDifferentiabilityKind() >=
824+
FunctionMetadataDifferentiabilityKind::Normal;
825+
}
826+
FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
827+
if (bool(Data & DifferentiableMask))
828+
return FunctionMetadataDifferentiabilityKind::Normal;
829+
if (bool(Data & LinearMask))
830+
return FunctionMetadataDifferentiabilityKind::Linear;
831+
return FunctionMetadataDifferentiabilityKind::NonDifferentiable;
811832
}
812833

813834
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }

include/swift/AST/Attr.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,9 @@ SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
432432
/* Not serialized */ 91)
433433
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
434434
OnVar, 92)
435+
DECL_ATTR(transposing, Transposing,
436+
OnFunc | LongAttribute | AllowMultipleAttributes |
437+
NotSerialized, 93)
435438

436439
#undef TYPE_ATTR
437440
#undef DECL_ATTR_ALIAS

include/swift/AST/Attr.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,79 @@ class DifferentiatingAttr final
16821682
return DA->getKind() == DAK_Differentiating;
16831683
}
16841684
};
1685+
1686+
/// Attribute that registers a function as a transpose of another function.
1687+
///
1688+
/// Examples:
1689+
/// @transposing(foo)
1690+
/// @transposing(+, wrt: (lhs, rhs))
1691+
class TransposingAttr final
1692+
: public DeclAttribute,
1693+
private llvm::TrailingObjects<DifferentiableAttr,
1694+
ParsedAutoDiffParameter> {
1695+
/// The base type of the original function.
1696+
/// This is non-null only when the original function is not top-level (i.e. it
1697+
/// is an instance/static method).
1698+
TypeRepr *BaseType;
1699+
/// The original function name.
1700+
DeclNameWithLoc Original;
1701+
/// The original function, resolved by the type checker.
1702+
FuncDecl *OriginalFunction = nullptr;
1703+
/// The number of parsed parameters specified in 'wrt:'.
1704+
unsigned NumParsedParameters = 0;
1705+
/// The differentiation parameters' indices, resolved by the type checker.
1706+
AutoDiffIndexSubset *ParameterIndexSubset = nullptr;
1707+
1708+
explicit TransposingAttr(ASTContext &context, bool implicit,
1709+
SourceLoc atLoc, SourceRange baseRange,
1710+
TypeRepr *baseType, DeclNameWithLoc original,
1711+
ArrayRef<ParsedAutoDiffParameter> params);
1712+
1713+
explicit TransposingAttr(ASTContext &context, bool implicit,
1714+
SourceLoc atLoc, SourceRange baseRange,
1715+
TypeRepr *baseType, DeclNameWithLoc original,
1716+
AutoDiffIndexSubset *indices);
1717+
1718+
public:
1719+
static TransposingAttr *create(ASTContext &context, bool implicit,
1720+
SourceLoc atLoc, SourceRange baseRange,
1721+
TypeRepr *baseType, DeclNameWithLoc original,
1722+
ArrayRef<ParsedAutoDiffParameter> params);
1723+
1724+
static TransposingAttr *create(ASTContext &context, bool implicit,
1725+
SourceLoc atLoc, SourceRange baseRange,
1726+
TypeRepr *baseType, DeclNameWithLoc original,
1727+
AutoDiffIndexSubset *indices);
1728+
1729+
TypeRepr *getBaseType() const { return BaseType; }
1730+
DeclNameWithLoc getOriginal() const { return Original; }
1731+
1732+
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
1733+
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
1734+
1735+
/// The parsed transposing parameters, i.e. the list of parameters
1736+
/// specified in 'wrt:'.
1737+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1738+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1739+
}
1740+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1741+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1742+
}
1743+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1744+
return NumParsedParameters;
1745+
}
1746+
1747+
AutoDiffIndexSubset *getParameterIndexSubset() const {
1748+
return ParameterIndexSubset;
1749+
}
1750+
void setParameterIndices(AutoDiffIndexSubset *pi) {
1751+
ParameterIndexSubset = pi;
1752+
}
1753+
1754+
static bool classof(const DeclAttribute *DA) {
1755+
return DA->getKind() == DAK_Transposing;
1756+
}
1757+
};
16851758

16861759
/// Attributes that may be applied to declarations.
16871760
class DeclAttributes {

include/swift/AST/AutoDiff.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "swift/Basic/Range.h"
2424

2525
namespace swift {
26-
26+
2727
enum class DifferentiabilityKind: uint8_t {
2828
NonDifferentiable = 0b00,
2929
Normal = 0b01,
@@ -354,6 +354,13 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
354354
unsigned getNumIndices() const {
355355
return (unsigned)std::distance(begin(), end());
356356
}
357+
358+
SmallBitVector getBitVector() const {
359+
SmallBitVector indicesBitVec(capacity, false);
360+
for (auto index : getIndices())
361+
indicesBitVec.set(index);
362+
return indicesBitVec;
363+
}
357364

358365
bool contains(unsigned index) const {
359366
unsigned bitWordIndex, offset;

include/swift/AST/DiagnosticsParse.def

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,13 +1512,23 @@ ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
15121512
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
15131513
"expected either 'linear' or 'wrt:'", ())
15141514

1515+
// transposing
1516+
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
1517+
"expected an original function name", ())
1518+
ERROR(attr_transposing_expected_label_linear_or_wrt,none,
1519+
"expected 'wrt:'", ())
1520+
1521+
// transposing `wrt` parameters clause
1522+
ERROR(transposing_params_clause_expected_parameter,PointsToFirstBadToken,
1523+
"expected a parameter, which can be a 'unsigned int' parameter number "
1524+
"or 'self'", ())
1525+
15151526
// differentiation `wrt` parameters clause
15161527
ERROR(expected_colon_after_label,PointsToFirstBadToken,
15171528
"expected a colon ':' after '%0'", (StringRef))
15181529
ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
15191530
"expected a parameter, which can be a function parameter name, "
1520-
"parameter index, or 'self'",
1521-
())
1531+
"parameter index, or 'self'", ())
15221532

15231533
// [differentiable ...] (sil-decl attr)
15241534
ERROR(sil_attr_differentiable_expected_keyword,PointsToFirstBadToken,

include/swift/AST/DiagnosticsSema.def

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2763,8 +2763,10 @@ ERROR(differentiable_attr_overload_not_found,none,
27632763
"%0 does not have expected type %1", (DeclName, Type))
27642764
ERROR(differentiable_attr_no_currying,none,
27652765
"'@differentiable' cannot be defined on functions that return functions", ())
2766+
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
2767+
// mention "same generic requirements".
27662768
ERROR(differentiable_attr_duplicate,none,
2767-
"duplicate '@differentiable' attribute with same parameter indices", ())
2769+
"duplicate '@differentiable' attribute with same parameters", ())
27682770
NOTE(differentiable_attr_duplicate_note,none,
27692771
"other attribute declared here", ())
27702772
ERROR(differentiable_attr_function_not_same_type_context,none,
@@ -2824,6 +2826,19 @@ ERROR(differentiating_attr_not_in_same_file_as_original,none,
28242826
ERROR(differentiating_attr_original_already_has_derivative,none,
28252827
"a derivative already exists for %0", (DeclName))
28262828

2829+
// transposing
2830+
ERROR(transpose_params_clause_param_not_differentiable,none,
2831+
"can only transpose with respect to parameters that conform to "
2832+
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
2833+
ERROR(transposing_attr_overload_not_found,none,
2834+
"could not find function %0 with expected type %1", (DeclName, Type))
2835+
ERROR(transposing_attr_cant_use_named_wrt_params,none,
2836+
"cannot use named wrt parameters in '@transposing' attribute, found %0",
2837+
(Identifier))
2838+
ERROR(transposing_attr_result_value_not_differentiable,none,
2839+
"'@transposing' attribute requires original function result to "
2840+
"conform to 'Differentiable'", (Type))
2841+
28272842
// differentiation `wrt` parameters clause
28282843
ERROR(diff_function_no_parameters,none,
28292844
"%0 has no parameters to differentiate with respect to", (DeclName))
@@ -2841,6 +2856,8 @@ ERROR(diff_params_clause_no_inferred_parameters,PointsToFirstBadToken,
28412856
"no differentiation parameters could be inferred; must differentiate "
28422857
"with respect to at least one parameter conforming to 'Differentiable'",
28432858
())
2859+
ERROR(diff_params_clause_inout_argument,none,
2860+
"'inout' parameters (%0) cannot be differentiated with respect to", (Type))
28442861
ERROR(diff_params_clause_cannot_diff_wrt_objects_or_existentials,none,
28452862
"class objects and protocol existentials (%0) cannot be differentiated "
28462863
"with respect to", (Type))
@@ -2865,9 +2882,9 @@ ERROR(compiler_evaluable_ref_non_compiler_evaluable,none,
28652882
"@compilerEvaluable functions may not reference non-@compilerEvaluable functions", ())
28662883

28672884
// @noDerivative attribute
2868-
ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
2869-
"'@noDerivative' is only allowed on stored properties in structure types "
2870-
"that declare a conformance to 'Differentiable'", ())
2885+
ERROR(noderivative_only_on_differentiable_struct_or_class_fields,none,
2886+
"'@noDerivative' is only allowed on stored properties in structure or "
2887+
"class types that declare a conformance to 'Differentiable'", ())
28712888

28722889
//------------------------------------------------------------------------------
28732890
// MARK: Type Check Expressions

include/swift/AST/Types.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
// SWIFT_ENABLE_TENSORFLOW
2121
#include "swift/AST/AutoDiff.h"
22+
#include "swift/AST/Attr.h"
2223
#include "swift/AST/DeclContext.h"
2324
#include "swift/AST/GenericParamKey.h"
2425
#include "swift/AST/Identifier.h"
@@ -274,7 +275,7 @@ class alignas(1 << TypeAlignInBits) TypeBase {
274275
TypeBase(const TypeBase&) = delete;
275276
void operator=(const TypeBase&) = delete;
276277

277-
/// This union contains to the ASTContext for canonical types, and is
278+
/// This union contains the ASTContext for canonical types, and is
278279
/// otherwise lazily populated by ASTContext when the canonical form of a
279280
/// non-canonical type is requested. The disposition of the union is stored
280281
/// outside of the union for performance. See Bits.TypeBase.IsCanonical.
@@ -3093,7 +3094,7 @@ class AnyFunctionType : public TypeBase {
30933094
///
30943095
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
30953096
/// first. This should be used during type-checking, e.g. type-checking
3096-
/// `@differentiable` and `@differentiating` attributes.
3097+
/// `@differentiable`, `@differentiating`, and `@transposing` attributes.
30973098
///
30983099
/// \note The original function type (`self`) need not be `@differentiable`.
30993100
/// The resulting function will preserve all `ExtInfo` of the original
@@ -3108,6 +3109,13 @@ class AnyFunctionType : public TypeBase {
31083109
/// Given the type of an autodiff associated function, returns the
31093110
/// corresponding original function type.
31103111
AnyFunctionType *getAutoDiffOriginalFunctionType();
3112+
3113+
/// Given the type of a transposing associated function, returns the
3114+
/// corresponding original function type.
3115+
AnyFunctionType *
3116+
getTransposeOriginalFunctionType(TransposingAttr *attr,
3117+
AutoDiffIndexSubset *wrtParamIndices,
3118+
bool wrtSelf);
31113119

31123120
AnyFunctionType *getWithoutDifferentiability() const;
31133121

include/swift/Demangling/DemangleNodes.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ NODE(DependentProtocolConformanceInherited)
6868
NODE(DependentProtocolConformanceAssociated)
6969
CONTEXT_NODE(Destructor)
7070
CONTEXT_NODE(DidSet)
71+
// SWIFT_ENABLE_TENSORFLOW
72+
NODE(DifferentiableFunctionType)
73+
NODE(EscapingDifferentiableFunctionType)
74+
NODE(LinearFunctionType)
75+
NODE(EscapingLinearFunctionType)
76+
// SWIFT_ENABLE_TENSORFLOW END
7177
NODE(Directness)
7278
NODE(DynamicAttribute)
7379
NODE(DirectMethodReferenceAttribute)

include/swift/Demangling/TypeDecoder.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,12 @@ class TypeDecoder {
494494
case NodeKind::NoEscapeFunctionType:
495495
case NodeKind::AutoClosureType:
496496
case NodeKind::EscapingAutoClosureType:
497+
// SWIFT_ENABLE_TENSORFLOW
498+
case NodeKind::DifferentiableFunctionType:
499+
case NodeKind::EscapingDifferentiableFunctionType:
500+
case NodeKind::LinearFunctionType:
501+
case NodeKind::EscapingLinearFunctionType:
502+
// SWIFT_ENABLE_TENSORFLOW END
497503
case NodeKind::FunctionType: {
498504
if (Node->getNumChildren() < 2)
499505
return BuiltType();
@@ -508,6 +514,17 @@ class TypeDecoder {
508514
} else if (Node->getKind() == NodeKind::ThinFunctionType) {
509515
flags = flags.withConvention(FunctionMetadataConvention::Thin);
510516
}
517+
// SWIFT_ENABLE_TENSORFLOW
518+
else if (Node->getKind() == NodeKind::DifferentiableFunctionType ||
519+
Node->getKind() ==
520+
NodeKind::EscapingDifferentiableFunctionType) {
521+
flags = flags.withDifferentiabilityKind(
522+
FunctionMetadataDifferentiabilityKind::Normal);
523+
} else if (Node->getKind() == NodeKind::LinearFunctionType ||
524+
Node->getKind() == NodeKind::EscapingLinearFunctionType) {
525+
flags = flags.withDifferentiabilityKind(
526+
FunctionMetadataDifferentiabilityKind::Linear);
527+
}
511528

512529
bool isThrow =
513530
Node->getChild(0)->getKind() == NodeKind::ThrowsAnnotation;
@@ -527,7 +544,12 @@ class TypeDecoder {
527544
.withEscaping(
528545
Node->getKind() == NodeKind::FunctionType ||
529546
Node->getKind() == NodeKind::EscapingAutoClosureType ||
530-
Node->getKind() == NodeKind::EscapingObjCBlock);
547+
Node->getKind() == NodeKind::EscapingObjCBlock ||
548+
// SWIFT_ENABLE_TENSORFLOW
549+
Node->getKind() ==
550+
NodeKind::EscapingDifferentiableFunctionType ||
551+
Node->getKind() ==
552+
NodeKind::EscapingLinearFunctionType);
531553

532554
auto result = decodeMangledType(Node->getChild(isThrow ? 2 : 1));
533555
if (!result) return BuiltType();

include/swift/Parse/Parser.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,10 +960,17 @@ class Parser {
960960
/// Parse a differentiation parameters clause.
961961
bool parseDifferentiationParametersClause(
962962
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
963+
964+
/// Parse a transposing parameters clause.
965+
bool parseTransposingParametersClause(
966+
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
963967

964968
/// Parse the @differentiating attribute.
965969
ParserResult<DifferentiatingAttr>
966970
parseDifferentiatingAttribute(SourceLoc AtLoc, SourceLoc Loc);
971+
972+
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
973+
SourceLoc Loc);
967974

968975
/// Parse a specific attribute.
969976
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);

include/swift/SIL/SILFunction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ class SILFunction
580580
}
581581

582582
/// Returns true if the function has parameters that are consumed by the
583-
// callee.
583+
/// callee.
584584
bool hasOwnedParameters() const {
585585
for (auto &ParamInfo : getLoweredFunctionType()->getParameters()) {
586586
if (ParamInfo.isConsumed())

0 commit comments

Comments
 (0)