Skip to content

[AutoDiff] Simplify function type differentiability to be binary #21211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -772,17 +772,6 @@ enum class FunctionMetadataConvention: uint8_t {
CFunctionPointer = 3,
};

// SWIFT_ENABLE_TENSORFLOW
/// Differentiability values for function type metadata.
enum class FunctionMetadataDifferentiability: uint8_t {
None = 0,
Forward = 1,
Reverse = 2,
Bidirectional = 3,
Linear = 4,
Constant = 5
};

/// Flags in a function type metadata record.
template <typename int_type>
class TargetFunctionTypeFlags {
Expand All @@ -797,8 +786,7 @@ class TargetFunctionTypeFlags {
ParamFlagsMask = 0x02000000U,
EscapingMask = 0x04000000U,
// SWIFT_ENABLE_TENSORFLOW
DifferentiabilityShift = 27U,
DifferentiabilityMask = 0x38000000U
DifferentiableMask = 0x08000000U
};
int_type Data;

Expand Down Expand Up @@ -837,9 +825,9 @@ class TargetFunctionTypeFlags {

// SWIFT_ENABLE_TENSORFLOW
constexpr TargetFunctionTypeFlags<int_type>
withDifferentiability(FunctionMetadataDifferentiability diffability) const {
return TargetFunctionTypeFlags((Data & ~DifferentiabilityMask)
| (int_type(diffability) << DifferentiabilityShift));
withDifferentiable(bool isDifferentiable) const {
return TargetFunctionTypeFlags<int_type>((Data & ~DifferentiableMask) |
(isDifferentiable ? DifferentiableMask : 0));
}

unsigned getNumParameters() const { return Data & NumParametersMask; }
Expand All @@ -857,9 +845,8 @@ class TargetFunctionTypeFlags {
}

// SWIFT_ENABLE_TENSORFLOW
FunctionMetadataDifferentiability getDifferentiability() const {
return FunctionMetadataDifferentiability(
(Data & DifferentiabilityMask) >> DifferentiabilityShift);
bool isDifferentiable() const {
return bool (Data & DifferentiableMask);
}

bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }
Expand Down
10 changes: 0 additions & 10 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ class TypeAttributes {
SourceLoc AtLoc;
Optional<StringRef> convention = None;
Optional<StringRef> conventionWitnessMethodProtocol = None;
// SWIFT_ENABLE_TENSORFLOW
Optional<std::pair<StringRef, int>> differentiabilityAndOrder = None;

// For an opened existential type, the known ID.
Optional<UUID> OpenedID;
Expand Down Expand Up @@ -112,14 +110,6 @@ class TypeAttributes {
bool hasConvention() const { return convention.hasValue(); }
StringRef getConvention() const { return *convention; }

// SWIFT_ENABLE_TENSORFLOW
bool hasDifferentiability() const {
return differentiabilityAndOrder.hasValue();
}
std::pair<StringRef, int> getDifferentiabilityAndOrder() const {
return *differentiabilityAndOrder;
}

bool hasOwnership() const {
return getOwnership() != ReferenceOwnership::Strong;
}
Expand Down
39 changes: 0 additions & 39 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,45 +241,6 @@ class AutoDiffParameterIndicesBuilder {
void setSelfParameter();
};

/// Differentiability of a function specifies the differentiation mode,
/// parameter indices at which the function is differentiable with respect to,
/// and indices of results which can be differentiated.
class Differentiability {
private:
// The differentiation mode.
AutoDiffMode mode;
// Differentiable with respect to `self`, applicable to methods only.
bool wrtSelf;
// Indices of parameters that are differentiable with respect to.
llvm::SmallBitVector parameterIndices;
// Indices of results that are differentiable.
llvm::SmallBitVector resultIndices;

public:
Differentiability(AutoDiffMode mode,
bool wrtSelf,
llvm::SmallBitVector parameterIndices,
llvm::SmallBitVector resultIndices);

Differentiability(AutoDiffMode mode, AnyFunctionType *type);

AutoDiffMode getMode() const {
return mode;
}

bool isWithRespectToSelf() const {
return wrtSelf;
}

const llvm::SmallBitVector &getParameterIndices() const {
return parameterIndices;
}

const llvm::SmallBitVector &getResultIndices() const {
return resultIndices;
}
};

/// SIL-level automatic differentiation indices. Consists of a source index,
/// i.e. index of the dependent result to differentiate from, and parameter
/// indices, i.e. index of independent parameters to differentiate with
Expand Down
6 changes: 0 additions & 6 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3680,12 +3680,6 @@ ERROR(unreferenced_generic_parameter,none,

// SWIFT_ENABLE_TENSORFLOW
// Function differentiability
ERROR(autodiff_attr_invalid_differentiability,none,
"invalid differentiability '%0' in '@autodiff' attribute; expected 'forward', 'reverse', 'linear', 'constant', or 'bidirectional'", (StringRef))
ERROR(autodiff_attr_order_cannot_be_zero,none,
"differentiation order cannot be zero; it should be at least first-order", ())
ERROR(autodiff_attr_order_cannot_be_specified_in_mode,none,
"differentiation order cannot be specified in '%0' mode", (StringRef))
ERROR(autodiff_attr_argument_not_differentiable,none,
"argument is not differentiable, but the enclosing function type is marked '@autodiff'; did you want to add '@nondiff' to this argument?", ())
ERROR(autodiff_attr_result_not_differentiable,none,
Expand Down
88 changes: 22 additions & 66 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {

protected:
// SWIFT_ENABLE_TENSORFLOW
enum { NumAFTExtInfoBits = 10 };
enum { NumSILExtInfoBits = 9 };
enum { NumAFTExtInfoBits = 8 };
enum { NumSILExtInfoBits = 7 };
union { uint64_t OpaqueBits;

SWIFT_INLINE_BITFIELD_BASE(TypeBase, bitmax(NumTypeKindBits,8) +
Expand Down Expand Up @@ -2706,23 +2706,6 @@ getSILFunctionLanguage(SILFunctionTypeRepresentation rep) {
llvm_unreachable("Unhandled SILFunctionTypeRepresentation in switch.");
}

// SWIFT_ENABLE_TENSORFLOW
/// The differentiability of a function type.
enum class FunctionTypeDifferentiability : uint8_t {
/// Non-differentiable.
None = 0,
/// Forward-mode differentiable.
Forward,
/// Reverse-mode differentiable.
Reverse,
/// Both forward-mode and reverse-mode differentiable.
Bidirectional,
/// Linear map.
Linear,
/// Constant function, whose derivatives are always zero.
Constant,
};

/// AnyFunctionType - A function type has zero or more input parameters and a
/// single result. The result type may be a tuple. For example:
/// "(int) -> int" or "(a : int, b : int) -> (int, int)".
Expand All @@ -2736,8 +2719,6 @@ class AnyFunctionType : public TypeBase {

public:
using Representation = FunctionTypeRepresentation;
// SWIFT_ENABLE_TENSORFLOW
using Differentiability = FunctionTypeDifferentiability;

class Param {
public:
Expand Down Expand Up @@ -2898,9 +2879,8 @@ class AnyFunctionType : public TypeBase {
NoEscapeMask = 1 << 5,
ThrowsMask = 1 << 6,
// SWIFT_ENABLE_TENSORFLOW
DifferentiabilityOffset = 7,
DifferentiabilityMask = 0b111 << DifferentiabilityOffset,
NumMaskBits = 10
DifferentiableMask = 1 << 7,
NumMaskBits = 8
};

unsigned Bits; // Naturally sized for speed.
Expand All @@ -2913,8 +2893,6 @@ class AnyFunctionType : public TypeBase {
// Constructor with all defaults.
ExtInfo() : Bits(0) {
assert(getRepresentation() == Representation::Swift);
// SWIFT_ENABLE_TENSORFLOW
assert(getDifferentiability() == Differentiability::None);
}

// Constructor for polymorphic type.
Expand All @@ -2926,31 +2904,25 @@ class AnyFunctionType : public TypeBase {
ExtInfo(Representation Rep,
bool IsAutoClosure, bool IsNoEscape,
// SWIFT_ENABLE_TENSORFLOW
bool Throws, Differentiability Diff)
bool Throws, bool IsDifferentiable)
: ExtInfo(Rep, Throws) {
Bits |= (IsAutoClosure ? AutoClosureMask : 0);
Bits |= (IsNoEscape ? NoEscapeMask : 0);
// SWIFT_ENABLE_TENSORFLOW
Bits |=
(((unsigned)Diff << DifferentiabilityOffset) & DifferentiabilityMask);
Bits |= (IsDifferentiable ? DifferentiableMask : 0);
}

bool isAutoClosure() const { return Bits & AutoClosureMask; }
bool isNoEscape() const { return Bits & NoEscapeMask; }
bool throws() const { return Bits & ThrowsMask; }
// SWIFT_ENABLE_TENSORFLOW
bool isDifferentiable() const { return Bits & DifferentiabilityMask; }
bool isDifferentiable() const { return Bits & DifferentiableMask; }
Representation getRepresentation() const {
unsigned rawRep = Bits & RepresentationMask;
assert(rawRep <= unsigned(Representation::Last)
&& "unexpected SIL representation");
return Representation(rawRep);
}
// SWIFT_ENABLE_TENSORFLOW
Differentiability getDifferentiability() const {
return Differentiability(
(Bits & DifferentiabilityMask) >> DifferentiabilityOffset);
}

bool hasSelfParam() const {
switch (getSILRepresentation()) {
Expand Down Expand Up @@ -3021,9 +2993,11 @@ class AnyFunctionType : public TypeBase {
}
// SWIFT_ENABLE_TENSORFLOW
LLVM_NODISCARD
ExtInfo withDifferentiability(Differentiability diff) const {
return ExtInfo((Bits & ~DifferentiabilityMask) |
(unsigned)diff << DifferentiabilityOffset);
ExtInfo withDifferentiable(bool isDifferentiable = true) const {
if (isDifferentiable)
return ExtInfo(Bits | DifferentiableMask);
else
return ExtInfo(Bits & ~DifferentiableMask);
}

unsigned getFuncAttrKey() const {
Expand Down Expand Up @@ -3109,12 +3083,6 @@ class AnyFunctionType : public TypeBase {
return getExtInfo().getRepresentation();
}

// SWIFT_ENABLE_TENSORFLOW
/// \brief Get the differentiability of the function type.
Differentiability getDifferentiability() const {
return getExtInfo().getDifferentiability();
}

/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
/// of the corresponding autodiff associated function.
///
Expand Down Expand Up @@ -3783,8 +3751,6 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
public:
using Language = SILFunctionLanguage;
using Representation = SILFunctionTypeRepresentation;
// SWIFT_ENABLE_TENSORFLOW
using Differentiability = FunctionTypeDifferentiability;

/// \brief A class which abstracts out some details necessary for
/// making a call.
Expand All @@ -3801,9 +3767,8 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
PseudogenericMask = 1 << 4,
NoEscapeMask = 1 << 5,
// SWIFT_ENABLE_TENSORFLOW
DifferentiabilityOffset = 6,
DifferentiabilityMask = 0b111 << DifferentiabilityOffset,
NumMaskBits = 9
DifferentiableMask = 1 << 6,
NumMaskBits = 7
};

unsigned Bits; // Naturally sized for speed.
Expand All @@ -3819,12 +3784,12 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
// Constructor for polymorphic type.
// SWIFT_ENABLE_TENSORFLOW
ExtInfo(Representation rep, bool isPseudogeneric, bool isNoEscape,
Differentiability diff) {
bool isDifferentiable) {
Bits = ((unsigned) rep) |
(isPseudogeneric ? PseudogenericMask : 0) |
// SWIFT_ENABLE_TENSORFLOW
(isNoEscape ? NoEscapeMask : 0) |
((unsigned)diff << DifferentiabilityOffset);
(isDifferentiable ? DifferentiableMask : 0);
}

/// Is this function pseudo-generic? A pseudo-generic function
Expand All @@ -3835,12 +3800,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
bool isNoEscape() const { return Bits & NoEscapeMask; }

// SWIFT_ENABLE_TENSORFLOW
bool isDifferentiable() const { return Bits & DifferentiabilityMask; }

Differentiability getDifferentiability() const {
return Differentiability(
(Bits & DifferentiabilityMask) >> DifferentiabilityOffset);
}
bool isDifferentiable() const { return Bits & DifferentiableMask; }

/// What is the abstract representation of this function value?
Representation getRepresentation() const {
Expand Down Expand Up @@ -3908,9 +3868,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
return ExtInfo(Bits & ~NoEscapeMask);
}
// SWIFT_ENABLE_TENSORFLOW
ExtInfo withDifferentiability(Differentiability diff) const {
return ExtInfo((Bits & ~DifferentiabilityMask) |
(unsigned)diff << DifferentiabilityOffset);
ExtInfo withDifferentiable(bool isDifferentiable = true) const {
if (isDifferentiable)
return ExtInfo(Bits | DifferentiableMask);
else
return ExtInfo(Bits & ~DifferentiableMask);
}

unsigned getFuncAttrKey() const {
Expand Down Expand Up @@ -4253,12 +4215,6 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
return getExtInfo().getRepresentation();
}

// SWIFT_ENABLE_TENSORFLOW
/// \brief Get the differentiability of the function type.
Differentiability getDifferentiability() const {
return getExtInfo().getDifferentiability();
}

bool isPseudogeneric() const {
return getExtInfo().isPseudogeneric();
}
Expand Down
21 changes: 4 additions & 17 deletions include/swift/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,19 +202,6 @@ enum class SILFunctionTypeRepresentation : uint8_t {
};
using SILFunctionTypeRepresentationField = BCFixed<4>;

// SWIFT_ENABLE_TENSORFLOW
// These IDs must \em not be renumbered or reordered without incrementing
// VERSION_MAJOR.
enum class FunctionTypeDifferentiability : uint8_t {
None = 0,
Forward,
Reverse,
Bidirectional,
Linear,
Constant,
};
using FunctionTypeDifferentiabilityField = BCFixed<3>;

// These IDs must \em not be renumbered or reordered without incrementing
// the module version.
enum class SILCoroutineKind : uint8_t {
Expand Down Expand Up @@ -752,7 +739,7 @@ namespace decls_block {
BCFixed<1>, // noescape?
// SWIFT_ENABLE_TENSORFLOW
BCFixed<1>, // throws?
FunctionTypeDifferentiabilityField // differentiability
BCFixed<1> // differentiable?
// trailed by parameters
>;

Expand Down Expand Up @@ -815,8 +802,8 @@ namespace decls_block {
FunctionTypeRepresentationField, // representation
BCFixed<1>, // throws?
// SWIFT_ENABLE_TENSORFLOW
GenericSignatureIDField, // generic signture
BCFixed<3> // differentiability
BCFixed<1>, // differentiable?
GenericSignatureIDField // generic signture

// trailed by parameters
>;
Expand All @@ -829,7 +816,7 @@ namespace decls_block {
BCFixed<1>, // pseudogeneric?
BCFixed<1>, // noescape?
// SWIFT_ENABLE_TENSORFLOW
FunctionTypeDifferentiabilityField, // differentiability
BCFixed<1>, // differentiable?
BCFixed<1>, // error result?
BCFixed<30>, // number of parameters
BCFixed<30>, // number of yields
Expand Down
Loading