Skip to content

uniquely allocate AutoDiffParameterIndices and AutoDiffAssociatedFunctionIdentifier #21193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 97 additions & 74 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,107 +79,60 @@ class AutoDiffParameter {
};

class AnyFunctionType;
class AutoDiffParameterIndicesBuilder;
class Type;

/// Identifies a subset of a function's parameters.
///
/// Works with AST-level function decls and types. Requires further lowering to
/// work with SIL-level functions and types. (In particular, tuples must be
/// exploded).
class AutoDiffParameterIndices {
///
/// Is uniquely allocated within an ASTContext so that it can be hashed and
/// compared by opaque pointer value.
class AutoDiffParameterIndices : public llvm::FoldingSetNode {
friend AutoDiffParameterIndicesBuilder;

/// Bits corresponding to parameters in the set are "on", and bits
/// corresponding to parameters not in the set are "off".
///
/// Normally, the bits correspond to the function's parameters in order. For
/// example,
/// For non-method functions, the bits correspond to the function's
//// parameters in order. For example,
///
/// Function type: (A, B, C) -> R
/// Bits: [A][B][C]
///
/// When `isMethodFlag` is set, the bits correspond to the function's
/// non-self parameters in order, followed by the function's self parameter.
/// For example,
/// For methods, the bits correspond to the function's non-self parameters
/// in order, followed by the function's self parameter. For example,
///
/// Function type: (Self) -> (A, B, C) -> R
/// Bits: [A][B][C][Self]
///
llvm::SmallBitVector indices;
const llvm::SmallBitVector indices;

/// Whether the function is a method.
///
bool isMethodFlag;
AutoDiffParameterIndices(llvm::SmallBitVector indices)
: indices(indices) {}

unsigned getNumNonSelfParameters() const;

AutoDiffParameterIndices(unsigned numIndices, bool isMethodFlag,
bool setAllParams = false)
: indices(numIndices, setAllParams), isMethodFlag(isMethodFlag) {}

AutoDiffParameterIndices(llvm::SmallBitVector indices, bool isMethodFlag)
: indices(indices), isMethodFlag(isMethodFlag) {}
static AutoDiffParameterIndices *get(llvm::SmallBitVector indices,
ASTContext &C);

public:
/// Allocates and initializes an empty `AutoDiffParameterIndices` for the
/// given `functionType`. `isMethod` specifies whether to treat the function
/// as a method.
static AutoDiffParameterIndices *
create(ASTContext &C, AnyFunctionType *functionType, bool isMethod,
bool setAllParams = false);

bool isMethod() const { return isMethodFlag; }

/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
/// the given `string` generated by `getString()`. If the string is invalid,
/// returns nullptr.
static AutoDiffParameterIndices *create(ASTContext &C, StringRef string);

/// Returns a textual string description of these indices,
///
/// [FM][SU]+
/// [SU]+
///
/// "F" means that `isMethodFlag` is false
/// "M" means that `isMethodFlag` is true
/// "S" means that the corresponding index is set
/// "U" means that the corresponding index is unset
std::string getString() const;

/// Tests whether this set of parameters is empty.
bool isEmpty() const { return indices.none(); }

/// Adds the indexed parameter to the set. When `isMethodFlag` is not set,
/// the indices index into the first parameter list. For example,
///
/// functionType = (A, B, C) -> R
/// paramIndex = 0
/// ==> adds "A" to the set.
///
/// When `isMethodFlag` is set, the indices index into the first non-self
/// parameter list. For example,
///
/// functionType = (Self) -> (A, B, C) -> R
/// paramIndex = 0
/// ==> adds "A" to the set.
///
void setNonSelfParameter(unsigned parameterIndex);

/// Adds all the paramaters from the first non-self parameter list to the set.
/// For example,
///
/// functionType = (A, B, C) -> R
/// ==> adds "A", B", and "C" to the set.
///
/// functionType = (Self) -> (A, B, C) -> R
/// ==> adds "A", B", and "C" to the set.
///
void setAllNonSelfParameters();

/// Adds the self parameter to the set. `isMethodFlag` must be set. For
/// example,
/// functionType = (Self) -> (A, B, C) -> R
/// ==> adds "Self" to the set
///
void setSelfParameter();

/// Pushes the subset's parameter's types to `paramTypes`, in the order in
/// which they appear in the function type. For example,
///
Expand All @@ -191,11 +144,14 @@ class AutoDiffParameterIndices {
/// if "Self" and "C" are in the set,
/// ==> pushes {Self, C} to `paramTypes`.
///
/// Pass `isMethod = true` when the function is a method.
///
/// Pass `selfUncurried = true` when the function type is for a method whose
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
///
void getSubsetParameterTypes(AnyFunctionType *functionType,
SmallVectorImpl<Type> &paramTypes,
bool isMethod,
bool selfUncurried = false) const;

/// Returns a bitvector for the SILFunction parameters corresponding to the
Expand All @@ -217,17 +173,74 @@ class AutoDiffParameterIndices {
/// ==> returns 1110
/// (because the lowered SIL type is (A, B, C, D) -> R)
///
/// Pass `isMethod = true` when the function is a method.
///
/// Pass `selfUncurried = true` when the function type is for a method whose
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
///
llvm::SmallBitVector getLowered(AnyFunctionType *functionType,
bool isMethod,
bool selfUncurried = false) const;

bool operator==(const AutoDiffParameterIndices &other) const {
return isMethodFlag == other.isMethodFlag && indices == other.indices;
void Profile(llvm::FoldingSetNodeID &ID) const {
ID.AddInteger(indices.size());
for (unsigned setBit : indices.set_bits())
ID.AddInteger(setBit);
}
};

/// Builder for `AutoDiffParameterIndices`.
class AutoDiffParameterIndicesBuilder {
llvm::SmallBitVector indices;
bool isMethod;

unsigned getNumNonSelfParameters() const;

public:
/// Start building an `AutoDiffParameterIndices` for the given function type.
/// `isMethod` specifies whether to treat the function as a method.
AutoDiffParameterIndicesBuilder(AnyFunctionType *functionType, bool isMethod,
bool setAllParams = false);

/// Builds the `AutoDiffParameterIndices`, returning a pointer to an existing
/// one if it has already been allocated in the `ASTContext`.
AutoDiffParameterIndices *build(ASTContext &C) const;

/// Adds the indexed parameter to the set. When `isMethod` is not set,
/// the indices index into the first parameter list. For example,
///
/// functionType = (A, B, C) -> R
/// paramIndex = 0
/// ==> adds "A" to the set.
///
/// When `isMethod` is set, the indices index into the first non-self
/// parameter list. For example,
///
/// functionType = (Self) -> (A, B, C) -> R
/// paramIndex = 0
/// ==> adds "A" to the set.
///
void setNonSelfParameter(unsigned parameterIndex);

/// Adds all the paramaters from the first non-self parameter list to the set.
/// For example,
///
/// functionType = (A, B, C) -> R
/// ==> adds "A", B", and "C" to the set.
///
/// functionType = (Self) -> (A, B, C) -> R
/// ==> adds "A", B", and "C" to the set.
///
void setAllNonSelfParameters();

/// Adds the self parameter to the set. `isMethod` must be set. For
/// example,
/// functionType = (Self) -> (A, B, C) -> R
/// ==> adds "Self" to the set
///
void setSelfParameter();
};

/// Differentiability of a function specifies the differentiation mode,
/// parameter indices at which the function is differentiable with respect to,
/// and indices of results which can be differentiated.
Expand Down Expand Up @@ -413,25 +426,35 @@ struct AutoDiffAssociatedFunctionKind {

/// In conjunction with the original function decl, identifies an associated
/// autodiff function.
class AutoDiffAssociatedFunctionIdentifier {
AutoDiffAssociatedFunctionKind kind;
unsigned differentiationOrder;
AutoDiffParameterIndices *parameterIndices;
///
/// Is uniquely allocated within an ASTContext so that it can be hashed and
/// compared by opaque pointer value.
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
const AutoDiffAssociatedFunctionKind kind;
const unsigned differentiationOrder;
AutoDiffParameterIndices * const parameterIndices;

AutoDiffAssociatedFunctionIdentifier(
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
AutoDiffParameterIndices *parameterIndices) :
kind(kind), differentiationOrder(differentiationOrder),
parameterIndices(parameterIndices) {}

public:
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
unsigned getDifferentiationOrder() const { return differentiationOrder; }
const AutoDiffParameterIndices *getParameterIndices() const {
AutoDiffParameterIndices *getParameterIndices() const {
return parameterIndices;
}

static AutoDiffAssociatedFunctionIdentifier *get(
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
AutoDiffParameterIndices *parameterIndices, ASTContext &C);

bool operator==(const AutoDiffAssociatedFunctionIdentifier &other) const {
return kind == other.kind && differentiationOrder == other.differentiationOrder &&
*parameterIndices == *other.parameterIndices;
void Profile(llvm::FoldingSetNodeID &ID) {
ID.AddInteger(kind);
ID.AddInteger(differentiationOrder);
ID.AddPointer(parameterIndices);
}
};

Expand Down
10 changes: 6 additions & 4 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3125,13 +3125,15 @@ class AnyFunctionType : public TypeBase {
/// resulting function will preserve all `ExtInfo` of the original function,
/// including `@autodiff`.
AnyFunctionType *getAutoDiffAssociatedFunctionType(
const AutoDiffParameterIndices &indices, unsigned resultIndex,
AutoDiffParameterIndices *indices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
LookupConformanceFn lookupConformance, bool selfUncurried = false);
LookupConformanceFn lookupConformance, bool isMethod,
bool selfUncurried = false);

AnyFunctionType *
getAutoDiffAdjointFunctionType(const AutoDiffParameterIndices &indices,
const TupleType *primalResultTy);
getAutoDiffAdjointFunctionType(AutoDiffParameterIndices *indices,
const TupleType *primalResultTy,
bool isMethod);

/// \brief True if this type allows an implicit conversion from a function
/// argument expression of type T to a function of type () -> T.
Expand Down
57 changes: 57 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,13 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
/// A cache of tangent spaces per type.
llvm::DenseMap<CanType, Optional<TangentSpace>> TangentSpaces;

/// For uniquifying `AutoDiffParameterIndices` allocations.
llvm::FoldingSet<AutoDiffParameterIndices> AutoDiffParameterIndicesSet;

/// For uniquifying `AutoDiffAssociatedFunctionIdentifier` allocations.
llvm::FoldingSet<AutoDiffAssociatedFunctionIdentifier>
AutoDiffAssociatedFunctionIdentifiers;

/// List of Objective-C member conflicts we have found during type checking.
std::vector<ObjCMethodConflict> ObjCMethodConflicts;

Expand Down Expand Up @@ -5269,3 +5276,53 @@ Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
// support differentiation.
return cache(None);
}

AutoDiffParameterIndices *
AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) {
auto &foldingSet = C.getImpl().AutoDiffParameterIndicesSet;

llvm::FoldingSetNodeID id;
id.AddInteger(indices.size());
for (unsigned setBit : indices.set_bits())
id.AddInteger(setBit);

void *insertPos;
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
if (existing)
return existing;

// TODO(SR-9290): Note that the AutoDiffParameterIndices' destructor never
// gets called, which causes a small memory leak in the case that the
// SmallBitVector decides to allocate some heap space.
void *mem = C.Allocate(sizeof(AutoDiffParameterIndices),
alignof(AutoDiffParameterIndices));
auto *newNode = ::new (mem) AutoDiffParameterIndices(indices);
foldingSet.InsertNode(newNode, insertPos);

return newNode;
}

AutoDiffAssociatedFunctionIdentifier *
AutoDiffAssociatedFunctionIdentifier::get(
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
AutoDiffParameterIndices *parameterIndices, ASTContext &C) {
auto &foldingSet = C.getImpl().AutoDiffAssociatedFunctionIdentifiers;

llvm::FoldingSetNodeID id;
id.AddInteger((unsigned)kind);
id.AddInteger(differentiationOrder);
id.AddPointer(parameterIndices);

void *insertPos;
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
if (existing)
return existing;

void *mem = C.Allocate(sizeof(AutoDiffAssociatedFunctionIdentifier),
alignof(AutoDiffAssociatedFunctionIdentifier));
auto *newNode = ::new (mem) AutoDiffAssociatedFunctionIdentifier(
kind, differentiationOrder, parameterIndices);
foldingSet.InsertNode(newNode, insertPos);

return newNode;
}
Loading