Skip to content

Commit cb9ac39

Browse files
authored
[AutoDiff] Simplify function type differentiability to be binary (#21211)
* No longer make the distinction between forward, reverse, etc as they are no longer needed in the updated design. * Remove a bunch of stale code.
1 parent db52338 commit cb9ac39

24 files changed

+77
-488
lines changed

include/swift/ABI/MetadataValues.h

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -772,17 +772,6 @@ enum class FunctionMetadataConvention: uint8_t {
772772
CFunctionPointer = 3,
773773
};
774774

775-
// SWIFT_ENABLE_TENSORFLOW
776-
/// Differentiability values for function type metadata.
777-
enum class FunctionMetadataDifferentiability: uint8_t {
778-
None = 0,
779-
Forward = 1,
780-
Reverse = 2,
781-
Bidirectional = 3,
782-
Linear = 4,
783-
Constant = 5
784-
};
785-
786775
/// Flags in a function type metadata record.
787776
template <typename int_type>
788777
class TargetFunctionTypeFlags {
@@ -797,8 +786,7 @@ class TargetFunctionTypeFlags {
797786
ParamFlagsMask = 0x02000000U,
798787
EscapingMask = 0x04000000U,
799788
// SWIFT_ENABLE_TENSORFLOW
800-
DifferentiabilityShift = 27U,
801-
DifferentiabilityMask = 0x38000000U
789+
DifferentiableMask = 0x08000000U
802790
};
803791
int_type Data;
804792

@@ -837,9 +825,9 @@ class TargetFunctionTypeFlags {
837825

838826
// SWIFT_ENABLE_TENSORFLOW
839827
constexpr TargetFunctionTypeFlags<int_type>
840-
withDifferentiability(FunctionMetadataDifferentiability diffability) const {
841-
return TargetFunctionTypeFlags((Data & ~DifferentiabilityMask)
842-
| (int_type(diffability) << DifferentiabilityShift));
828+
withDifferentiable(bool isDifferentiable) const {
829+
return TargetFunctionTypeFlags<int_type>((Data & ~DifferentiableMask) |
830+
(isDifferentiable ? DifferentiableMask : 0));
843831
}
844832

845833
unsigned getNumParameters() const { return Data & NumParametersMask; }
@@ -857,9 +845,8 @@ class TargetFunctionTypeFlags {
857845
}
858846

859847
// SWIFT_ENABLE_TENSORFLOW
860-
FunctionMetadataDifferentiability getDifferentiability() const {
861-
return FunctionMetadataDifferentiability(
862-
(Data & DifferentiabilityMask) >> DifferentiabilityShift);
848+
bool isDifferentiable() const {
849+
return bool (Data & DifferentiableMask);
863850
}
864851

865852
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }

include/swift/AST/Attr.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ class TypeAttributes {
6363
SourceLoc AtLoc;
6464
Optional<StringRef> convention = None;
6565
Optional<StringRef> conventionWitnessMethodProtocol = None;
66-
// SWIFT_ENABLE_TENSORFLOW
67-
Optional<std::pair<StringRef, int>> differentiabilityAndOrder = None;
6866

6967
// For an opened existential type, the known ID.
7068
Optional<UUID> OpenedID;
@@ -112,14 +110,6 @@ class TypeAttributes {
112110
bool hasConvention() const { return convention.hasValue(); }
113111
StringRef getConvention() const { return *convention; }
114112

115-
// SWIFT_ENABLE_TENSORFLOW
116-
bool hasDifferentiability() const {
117-
return differentiabilityAndOrder.hasValue();
118-
}
119-
std::pair<StringRef, int> getDifferentiabilityAndOrder() const {
120-
return *differentiabilityAndOrder;
121-
}
122-
123113
bool hasOwnership() const {
124114
return getOwnership() != ReferenceOwnership::Strong;
125115
}

include/swift/AST/AutoDiff.h

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -241,45 +241,6 @@ class AutoDiffParameterIndicesBuilder {
241241
void setSelfParameter();
242242
};
243243

244-
/// Differentiability of a function specifies the differentiation mode,
245-
/// parameter indices at which the function is differentiable with respect to,
246-
/// and indices of results which can be differentiated.
247-
class Differentiability {
248-
private:
249-
// The differentiation mode.
250-
AutoDiffMode mode;
251-
// Differentiable with respect to `self`, applicable to methods only.
252-
bool wrtSelf;
253-
// Indices of parameters that are differentiable with respect to.
254-
llvm::SmallBitVector parameterIndices;
255-
// Indices of results that are differentiable.
256-
llvm::SmallBitVector resultIndices;
257-
258-
public:
259-
Differentiability(AutoDiffMode mode,
260-
bool wrtSelf,
261-
llvm::SmallBitVector parameterIndices,
262-
llvm::SmallBitVector resultIndices);
263-
264-
Differentiability(AutoDiffMode mode, AnyFunctionType *type);
265-
266-
AutoDiffMode getMode() const {
267-
return mode;
268-
}
269-
270-
bool isWithRespectToSelf() const {
271-
return wrtSelf;
272-
}
273-
274-
const llvm::SmallBitVector &getParameterIndices() const {
275-
return parameterIndices;
276-
}
277-
278-
const llvm::SmallBitVector &getResultIndices() const {
279-
return resultIndices;
280-
}
281-
};
282-
283244
/// SIL-level automatic differentiation indices. Consists of a source index,
284245
/// i.e. index of the dependent result to differentiate from, and parameter
285246
/// indices, i.e. index of independent parameters to differentiate with

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3680,12 +3680,6 @@ ERROR(unreferenced_generic_parameter,none,
36803680

36813681
// SWIFT_ENABLE_TENSORFLOW
36823682
// Function differentiability
3683-
ERROR(autodiff_attr_invalid_differentiability,none,
3684-
"invalid differentiability '%0' in '@autodiff' attribute; expected 'forward', 'reverse', 'linear', 'constant', or 'bidirectional'", (StringRef))
3685-
ERROR(autodiff_attr_order_cannot_be_zero,none,
3686-
"differentiation order cannot be zero; it should be at least first-order", ())
3687-
ERROR(autodiff_attr_order_cannot_be_specified_in_mode,none,
3688-
"differentiation order cannot be specified in '%0' mode", (StringRef))
36893683
ERROR(autodiff_attr_argument_not_differentiable,none,
36903684
"argument is not differentiable, but the enclosing function type is marked '@autodiff'; did you want to add '@nondiff' to this argument?", ())
36913685
ERROR(autodiff_attr_result_not_differentiable,none,

include/swift/AST/Types.h

Lines changed: 22 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {
281281

282282
protected:
283283
// SWIFT_ENABLE_TENSORFLOW
284-
enum { NumAFTExtInfoBits = 10 };
285-
enum { NumSILExtInfoBits = 9 };
284+
enum { NumAFTExtInfoBits = 8 };
285+
enum { NumSILExtInfoBits = 7 };
286286
union { uint64_t OpaqueBits;
287287

288288
SWIFT_INLINE_BITFIELD_BASE(TypeBase, bitmax(NumTypeKindBits,8) +
@@ -2706,23 +2706,6 @@ getSILFunctionLanguage(SILFunctionTypeRepresentation rep) {
27062706
llvm_unreachable("Unhandled SILFunctionTypeRepresentation in switch.");
27072707
}
27082708

2709-
// SWIFT_ENABLE_TENSORFLOW
2710-
/// The differentiability of a function type.
2711-
enum class FunctionTypeDifferentiability : uint8_t {
2712-
/// Non-differentiable.
2713-
None = 0,
2714-
/// Forward-mode differentiable.
2715-
Forward,
2716-
/// Reverse-mode differentiable.
2717-
Reverse,
2718-
/// Both forward-mode and reverse-mode differentiable.
2719-
Bidirectional,
2720-
/// Linear map.
2721-
Linear,
2722-
/// Constant function, whose derivatives are always zero.
2723-
Constant,
2724-
};
2725-
27262709
/// AnyFunctionType - A function type has zero or more input parameters and a
27272710
/// single result. The result type may be a tuple. For example:
27282711
/// "(int) -> int" or "(a : int, b : int) -> (int, int)".
@@ -2736,8 +2719,6 @@ class AnyFunctionType : public TypeBase {
27362719

27372720
public:
27382721
using Representation = FunctionTypeRepresentation;
2739-
// SWIFT_ENABLE_TENSORFLOW
2740-
using Differentiability = FunctionTypeDifferentiability;
27412722

27422723
class Param {
27432724
public:
@@ -2898,9 +2879,8 @@ class AnyFunctionType : public TypeBase {
28982879
NoEscapeMask = 1 << 5,
28992880
ThrowsMask = 1 << 6,
29002881
// SWIFT_ENABLE_TENSORFLOW
2901-
DifferentiabilityOffset = 7,
2902-
DifferentiabilityMask = 0b111 << DifferentiabilityOffset,
2903-
NumMaskBits = 10
2882+
DifferentiableMask = 1 << 7,
2883+
NumMaskBits = 8
29042884
};
29052885

29062886
unsigned Bits; // Naturally sized for speed.
@@ -2913,8 +2893,6 @@ class AnyFunctionType : public TypeBase {
29132893
// Constructor with all defaults.
29142894
ExtInfo() : Bits(0) {
29152895
assert(getRepresentation() == Representation::Swift);
2916-
// SWIFT_ENABLE_TENSORFLOW
2917-
assert(getDifferentiability() == Differentiability::None);
29182896
}
29192897

29202898
// Constructor for polymorphic type.
@@ -2926,31 +2904,25 @@ class AnyFunctionType : public TypeBase {
29262904
ExtInfo(Representation Rep,
29272905
bool IsAutoClosure, bool IsNoEscape,
29282906
// SWIFT_ENABLE_TENSORFLOW
2929-
bool Throws, Differentiability Diff)
2907+
bool Throws, bool IsDifferentiable)
29302908
: ExtInfo(Rep, Throws) {
29312909
Bits |= (IsAutoClosure ? AutoClosureMask : 0);
29322910
Bits |= (IsNoEscape ? NoEscapeMask : 0);
29332911
// SWIFT_ENABLE_TENSORFLOW
2934-
Bits |=
2935-
(((unsigned)Diff << DifferentiabilityOffset) & DifferentiabilityMask);
2912+
Bits |= (IsDifferentiable ? DifferentiableMask : 0);
29362913
}
29372914

29382915
bool isAutoClosure() const { return Bits & AutoClosureMask; }
29392916
bool isNoEscape() const { return Bits & NoEscapeMask; }
29402917
bool throws() const { return Bits & ThrowsMask; }
29412918
// SWIFT_ENABLE_TENSORFLOW
2942-
bool isDifferentiable() const { return Bits & DifferentiabilityMask; }
2919+
bool isDifferentiable() const { return Bits & DifferentiableMask; }
29432920
Representation getRepresentation() const {
29442921
unsigned rawRep = Bits & RepresentationMask;
29452922
assert(rawRep <= unsigned(Representation::Last)
29462923
&& "unexpected SIL representation");
29472924
return Representation(rawRep);
29482925
}
2949-
// SWIFT_ENABLE_TENSORFLOW
2950-
Differentiability getDifferentiability() const {
2951-
return Differentiability(
2952-
(Bits & DifferentiabilityMask) >> DifferentiabilityOffset);
2953-
}
29542926

29552927
bool hasSelfParam() const {
29562928
switch (getSILRepresentation()) {
@@ -3021,9 +2993,11 @@ class AnyFunctionType : public TypeBase {
30212993
}
30222994
// SWIFT_ENABLE_TENSORFLOW
30232995
LLVM_NODISCARD
3024-
ExtInfo withDifferentiability(Differentiability diff) const {
3025-
return ExtInfo((Bits & ~DifferentiabilityMask) |
3026-
(unsigned)diff << DifferentiabilityOffset);
2996+
ExtInfo withDifferentiable(bool isDifferentiable = true) const {
2997+
if (isDifferentiable)
2998+
return ExtInfo(Bits | DifferentiableMask);
2999+
else
3000+
return ExtInfo(Bits & ~DifferentiableMask);
30273001
}
30283002

30293003
unsigned getFuncAttrKey() const {
@@ -3109,12 +3083,6 @@ class AnyFunctionType : public TypeBase {
31093083
return getExtInfo().getRepresentation();
31103084
}
31113085

3112-
// SWIFT_ENABLE_TENSORFLOW
3113-
/// \brief Get the differentiability of the function type.
3114-
Differentiability getDifferentiability() const {
3115-
return getExtInfo().getDifferentiability();
3116-
}
3117-
31183086
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
31193087
/// of the corresponding autodiff associated function.
31203088
///
@@ -3783,8 +3751,6 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37833751
public:
37843752
using Language = SILFunctionLanguage;
37853753
using Representation = SILFunctionTypeRepresentation;
3786-
// SWIFT_ENABLE_TENSORFLOW
3787-
using Differentiability = FunctionTypeDifferentiability;
37883754

37893755
/// \brief A class which abstracts out some details necessary for
37903756
/// making a call.
@@ -3801,9 +3767,8 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
38013767
PseudogenericMask = 1 << 4,
38023768
NoEscapeMask = 1 << 5,
38033769
// SWIFT_ENABLE_TENSORFLOW
3804-
DifferentiabilityOffset = 6,
3805-
DifferentiabilityMask = 0b111 << DifferentiabilityOffset,
3806-
NumMaskBits = 9
3770+
DifferentiableMask = 1 << 6,
3771+
NumMaskBits = 7
38073772
};
38083773

38093774
unsigned Bits; // Naturally sized for speed.
@@ -3819,12 +3784,12 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
38193784
// Constructor for polymorphic type.
38203785
// SWIFT_ENABLE_TENSORFLOW
38213786
ExtInfo(Representation rep, bool isPseudogeneric, bool isNoEscape,
3822-
Differentiability diff) {
3787+
bool isDifferentiable) {
38233788
Bits = ((unsigned) rep) |
38243789
(isPseudogeneric ? PseudogenericMask : 0) |
38253790
// SWIFT_ENABLE_TENSORFLOW
38263791
(isNoEscape ? NoEscapeMask : 0) |
3827-
((unsigned)diff << DifferentiabilityOffset);
3792+
(isDifferentiable ? DifferentiableMask : 0);
38283793
}
38293794

38303795
/// Is this function pseudo-generic? A pseudo-generic function
@@ -3835,12 +3800,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
38353800
bool isNoEscape() const { return Bits & NoEscapeMask; }
38363801

38373802
// SWIFT_ENABLE_TENSORFLOW
3838-
bool isDifferentiable() const { return Bits & DifferentiabilityMask; }
3839-
3840-
Differentiability getDifferentiability() const {
3841-
return Differentiability(
3842-
(Bits & DifferentiabilityMask) >> DifferentiabilityOffset);
3843-
}
3803+
bool isDifferentiable() const { return Bits & DifferentiableMask; }
38443804

38453805
/// What is the abstract representation of this function value?
38463806
Representation getRepresentation() const {
@@ -3908,9 +3868,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
39083868
return ExtInfo(Bits & ~NoEscapeMask);
39093869
}
39103870
// SWIFT_ENABLE_TENSORFLOW
3911-
ExtInfo withDifferentiability(Differentiability diff) const {
3912-
return ExtInfo((Bits & ~DifferentiabilityMask) |
3913-
(unsigned)diff << DifferentiabilityOffset);
3871+
ExtInfo withDifferentiable(bool isDifferentiable = true) const {
3872+
if (isDifferentiable)
3873+
return ExtInfo(Bits | DifferentiableMask);
3874+
else
3875+
return ExtInfo(Bits & ~DifferentiableMask);
39143876
}
39153877

39163878
unsigned getFuncAttrKey() const {
@@ -4253,12 +4215,6 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42534215
return getExtInfo().getRepresentation();
42544216
}
42554217

4256-
// SWIFT_ENABLE_TENSORFLOW
4257-
/// \brief Get the differentiability of the function type.
4258-
Differentiability getDifferentiability() const {
4259-
return getExtInfo().getDifferentiability();
4260-
}
4261-
42624218
bool isPseudogeneric() const {
42634219
return getExtInfo().isPseudogeneric();
42644220
}

include/swift/Serialization/ModuleFormat.h

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -202,19 +202,6 @@ enum class SILFunctionTypeRepresentation : uint8_t {
202202
};
203203
using SILFunctionTypeRepresentationField = BCFixed<4>;
204204

205-
// SWIFT_ENABLE_TENSORFLOW
206-
// These IDs must \em not be renumbered or reordered without incrementing
207-
// VERSION_MAJOR.
208-
enum class FunctionTypeDifferentiability : uint8_t {
209-
None = 0,
210-
Forward,
211-
Reverse,
212-
Bidirectional,
213-
Linear,
214-
Constant,
215-
};
216-
using FunctionTypeDifferentiabilityField = BCFixed<3>;
217-
218205
// These IDs must \em not be renumbered or reordered without incrementing
219206
// the module version.
220207
enum class SILCoroutineKind : uint8_t {
@@ -752,7 +739,7 @@ namespace decls_block {
752739
BCFixed<1>, // noescape?
753740
// SWIFT_ENABLE_TENSORFLOW
754741
BCFixed<1>, // throws?
755-
FunctionTypeDifferentiabilityField // differentiability
742+
BCFixed<1> // differentiable?
756743
// trailed by parameters
757744
>;
758745

@@ -815,8 +802,8 @@ namespace decls_block {
815802
FunctionTypeRepresentationField, // representation
816803
BCFixed<1>, // throws?
817804
// SWIFT_ENABLE_TENSORFLOW
818-
GenericSignatureIDField, // generic signture
819-
BCFixed<3> // differentiability
805+
BCFixed<1>, // differentiable?
806+
GenericSignatureIDField // generic signture
820807

821808
// trailed by parameters
822809
>;
@@ -829,7 +816,7 @@ namespace decls_block {
829816
BCFixed<1>, // pseudogeneric?
830817
BCFixed<1>, // noescape?
831818
// SWIFT_ENABLE_TENSORFLOW
832-
FunctionTypeDifferentiabilityField, // differentiability
819+
BCFixed<1>, // differentiable?
833820
BCFixed<1>, // error result?
834821
BCFixed<30>, // number of parameters
835822
BCFixed<30>, // number of yields

0 commit comments

Comments
 (0)