@@ -281,8 +281,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {
281
281
282
282
protected:
283
283
// SWIFT_ENABLE_TENSORFLOW
284
- enum { NumAFTExtInfoBits = 10 };
285
- enum { NumSILExtInfoBits = 9 };
284
+ enum { NumAFTExtInfoBits = 8 };
285
+ enum { NumSILExtInfoBits = 7 };
286
286
union { uint64_t OpaqueBits;
287
287
288
288
SWIFT_INLINE_BITFIELD_BASE (TypeBase, bitmax (NumTypeKindBits,8 ) +
@@ -2706,23 +2706,6 @@ getSILFunctionLanguage(SILFunctionTypeRepresentation rep) {
2706
2706
llvm_unreachable (" Unhandled SILFunctionTypeRepresentation in switch." );
2707
2707
}
2708
2708
2709
- // SWIFT_ENABLE_TENSORFLOW
2710
- // / The differentiability of a function type.
2711
- enum class FunctionTypeDifferentiability : uint8_t {
2712
- // / Non-differentiable.
2713
- None = 0 ,
2714
- // / Forward-mode differentiable.
2715
- Forward,
2716
- // / Reverse-mode differentiable.
2717
- Reverse,
2718
- // / Both forward-mode and reverse-mode differentiable.
2719
- Bidirectional,
2720
- // / Linear map.
2721
- Linear,
2722
- // / Constant function, whose derivatives are always zero.
2723
- Constant,
2724
- };
2725
-
2726
2709
// / AnyFunctionType - A function type has zero or more input parameters and a
2727
2710
// / single result. The result type may be a tuple. For example:
2728
2711
// / "(int) -> int" or "(a : int, b : int) -> (int, int)".
@@ -2736,8 +2719,6 @@ class AnyFunctionType : public TypeBase {
2736
2719
2737
2720
public:
2738
2721
using Representation = FunctionTypeRepresentation;
2739
- // SWIFT_ENABLE_TENSORFLOW
2740
- using Differentiability = FunctionTypeDifferentiability;
2741
2722
2742
2723
class Param {
2743
2724
public:
@@ -2898,9 +2879,8 @@ class AnyFunctionType : public TypeBase {
2898
2879
NoEscapeMask = 1 << 5 ,
2899
2880
ThrowsMask = 1 << 6 ,
2900
2881
// SWIFT_ENABLE_TENSORFLOW
2901
- DifferentiabilityOffset = 7 ,
2902
- DifferentiabilityMask = 0b111 << DifferentiabilityOffset,
2903
- NumMaskBits = 10
2882
+ DifferentiableMask = 1 << 7 ,
2883
+ NumMaskBits = 8
2904
2884
};
2905
2885
2906
2886
unsigned Bits; // Naturally sized for speed.
@@ -2913,8 +2893,6 @@ class AnyFunctionType : public TypeBase {
2913
2893
// Constructor with all defaults.
2914
2894
ExtInfo () : Bits(0 ) {
2915
2895
assert (getRepresentation () == Representation::Swift);
2916
- // SWIFT_ENABLE_TENSORFLOW
2917
- assert (getDifferentiability () == Differentiability::None);
2918
2896
}
2919
2897
2920
2898
// Constructor for polymorphic type.
@@ -2926,31 +2904,25 @@ class AnyFunctionType : public TypeBase {
2926
2904
ExtInfo (Representation Rep,
2927
2905
bool IsAutoClosure, bool IsNoEscape,
2928
2906
// SWIFT_ENABLE_TENSORFLOW
2929
- bool Throws, Differentiability Diff )
2907
+ bool Throws, bool IsDifferentiable )
2930
2908
: ExtInfo(Rep, Throws) {
2931
2909
Bits |= (IsAutoClosure ? AutoClosureMask : 0 );
2932
2910
Bits |= (IsNoEscape ? NoEscapeMask : 0 );
2933
2911
// SWIFT_ENABLE_TENSORFLOW
2934
- Bits |=
2935
- (((unsigned )Diff << DifferentiabilityOffset) & DifferentiabilityMask);
2912
+ Bits |= (IsDifferentiable ? DifferentiableMask : 0 );
2936
2913
}
2937
2914
2938
2915
bool isAutoClosure () const { return Bits & AutoClosureMask; }
2939
2916
bool isNoEscape () const { return Bits & NoEscapeMask; }
2940
2917
bool throws () const { return Bits & ThrowsMask; }
2941
2918
// SWIFT_ENABLE_TENSORFLOW
2942
- bool isDifferentiable () const { return Bits & DifferentiabilityMask ; }
2919
+ bool isDifferentiable () const { return Bits & DifferentiableMask ; }
2943
2920
Representation getRepresentation () const {
2944
2921
unsigned rawRep = Bits & RepresentationMask;
2945
2922
assert (rawRep <= unsigned (Representation::Last)
2946
2923
&& " unexpected SIL representation" );
2947
2924
return Representation (rawRep);
2948
2925
}
2949
- // SWIFT_ENABLE_TENSORFLOW
2950
- Differentiability getDifferentiability () const {
2951
- return Differentiability (
2952
- (Bits & DifferentiabilityMask) >> DifferentiabilityOffset);
2953
- }
2954
2926
2955
2927
bool hasSelfParam () const {
2956
2928
switch (getSILRepresentation ()) {
@@ -3021,9 +2993,11 @@ class AnyFunctionType : public TypeBase {
3021
2993
}
3022
2994
// SWIFT_ENABLE_TENSORFLOW
3023
2995
LLVM_NODISCARD
3024
- ExtInfo withDifferentiability (Differentiability diff) const {
3025
- return ExtInfo ((Bits & ~DifferentiabilityMask) |
3026
- (unsigned )diff << DifferentiabilityOffset);
2996
+ ExtInfo withDifferentiable (bool isDifferentiable = true ) const {
2997
+ if (isDifferentiable)
2998
+ return ExtInfo (Bits | DifferentiableMask);
2999
+ else
3000
+ return ExtInfo (Bits & ~DifferentiableMask);
3027
3001
}
3028
3002
3029
3003
unsigned getFuncAttrKey () const {
@@ -3109,12 +3083,6 @@ class AnyFunctionType : public TypeBase {
3109
3083
return getExtInfo ().getRepresentation ();
3110
3084
}
3111
3085
3112
- // SWIFT_ENABLE_TENSORFLOW
3113
- // / \brief Get the differentiability of the function type.
3114
- Differentiability getDifferentiability () const {
3115
- return getExtInfo ().getDifferentiability ();
3116
- }
3117
-
3118
3086
// / Given `indices`, `differentiationOrder`, and `kind`, calculates the type
3119
3087
// / of the corresponding autodiff associated function.
3120
3088
// /
@@ -3783,8 +3751,6 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
3783
3751
public:
3784
3752
using Language = SILFunctionLanguage;
3785
3753
using Representation = SILFunctionTypeRepresentation;
3786
- // SWIFT_ENABLE_TENSORFLOW
3787
- using Differentiability = FunctionTypeDifferentiability;
3788
3754
3789
3755
// / \brief A class which abstracts out some details necessary for
3790
3756
// / making a call.
@@ -3801,9 +3767,8 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
3801
3767
PseudogenericMask = 1 << 4 ,
3802
3768
NoEscapeMask = 1 << 5 ,
3803
3769
// SWIFT_ENABLE_TENSORFLOW
3804
- DifferentiabilityOffset = 6 ,
3805
- DifferentiabilityMask = 0b111 << DifferentiabilityOffset,
3806
- NumMaskBits = 9
3770
+ DifferentiableMask = 1 << 6 ,
3771
+ NumMaskBits = 7
3807
3772
};
3808
3773
3809
3774
unsigned Bits; // Naturally sized for speed.
@@ -3819,12 +3784,12 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
3819
3784
// Constructor for polymorphic type.
3820
3785
// SWIFT_ENABLE_TENSORFLOW
3821
3786
ExtInfo (Representation rep, bool isPseudogeneric, bool isNoEscape,
3822
- Differentiability diff ) {
3787
+ bool isDifferentiable ) {
3823
3788
Bits = ((unsigned ) rep) |
3824
3789
(isPseudogeneric ? PseudogenericMask : 0 ) |
3825
3790
// SWIFT_ENABLE_TENSORFLOW
3826
3791
(isNoEscape ? NoEscapeMask : 0 ) |
3827
- (( unsigned )diff << DifferentiabilityOffset );
3792
+ (isDifferentiable ? DifferentiableMask : 0 );
3828
3793
}
3829
3794
3830
3795
// / Is this function pseudo-generic? A pseudo-generic function
@@ -3835,12 +3800,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
3835
3800
bool isNoEscape () const { return Bits & NoEscapeMask; }
3836
3801
3837
3802
// SWIFT_ENABLE_TENSORFLOW
3838
- bool isDifferentiable () const { return Bits & DifferentiabilityMask; }
3839
-
3840
- Differentiability getDifferentiability () const {
3841
- return Differentiability (
3842
- (Bits & DifferentiabilityMask) >> DifferentiabilityOffset);
3843
- }
3803
+ bool isDifferentiable () const { return Bits & DifferentiableMask; }
3844
3804
3845
3805
// / What is the abstract representation of this function value?
3846
3806
Representation getRepresentation () const {
@@ -3908,9 +3868,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
3908
3868
return ExtInfo (Bits & ~NoEscapeMask);
3909
3869
}
3910
3870
// SWIFT_ENABLE_TENSORFLOW
3911
- ExtInfo withDifferentiability (Differentiability diff) const {
3912
- return ExtInfo ((Bits & ~DifferentiabilityMask) |
3913
- (unsigned )diff << DifferentiabilityOffset);
3871
+ ExtInfo withDifferentiable (bool isDifferentiable = true ) const {
3872
+ if (isDifferentiable)
3873
+ return ExtInfo (Bits | DifferentiableMask);
3874
+ else
3875
+ return ExtInfo (Bits & ~DifferentiableMask);
3914
3876
}
3915
3877
3916
3878
unsigned getFuncAttrKey () const {
@@ -4253,12 +4215,6 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
4253
4215
return getExtInfo ().getRepresentation ();
4254
4216
}
4255
4217
4256
- // SWIFT_ENABLE_TENSORFLOW
4257
- // / \brief Get the differentiability of the function type.
4258
- Differentiability getDifferentiability () const {
4259
- return getExtInfo ().getDifferentiability ();
4260
- }
4261
-
4262
4218
bool isPseudogeneric () const {
4263
4219
return getExtInfo ().isPseudogeneric ();
4264
4220
}
0 commit comments