Skip to content

Commit 2c93727

Browse files
author
Marc Rasi
committed
uniquely allocate AutoDiffParameterIndices and AutoDiffAssociatedFunctionIdentifier
1 parent 07579ea commit 2c93727

File tree

12 files changed

+250
-175
lines changed

12 files changed

+250
-175
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 87 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,20 @@ class AutoDiffParameter {
7979
};
8080

8181
class AnyFunctionType;
82+
class AutoDiffParameterIndicesBuilder;
8283
class Type;
8384

8485
/// Identifies a subset of a function's parameters.
8586
///
8687
/// Works with AST-level function decls and types. Requires further lowering to
8788
/// work with SIL-level functions and types. (In particular, tuples must be
8889
/// exploded).
89-
class AutoDiffParameterIndices {
90+
///
91+
/// Is uniquely allocated within an ASTContext so that it can be hashed and
92+
/// compared by opaque pointer value.
93+
class AutoDiffParameterIndices : public llvm::FoldingSetNode {
94+
friend AutoDiffParameterIndicesBuilder;
95+
9096
/// Bits corresponding to parameters in the set are "on", and bits
9197
/// corresponding to parameters not in the set are "off".
9298
///
@@ -103,29 +109,19 @@ class AutoDiffParameterIndices {
103109
/// Function type: (Self) -> (A, B, C) -> R
104110
/// Bits: [A][B][C][Self]
105111
///
106-
llvm::SmallBitVector indices;
112+
const llvm::SmallBitVector indices;
107113

108114
/// Whether the function is a method.
109115
///
110-
bool isMethodFlag;
111-
112-
unsigned getNumNonSelfParameters() const;
113-
114-
AutoDiffParameterIndices(unsigned numIndices, bool isMethodFlag,
115-
bool setAllParams = false)
116-
: indices(numIndices, setAllParams), isMethodFlag(isMethodFlag) {}
116+
const bool isMethodFlag;
117117

118118
AutoDiffParameterIndices(llvm::SmallBitVector indices, bool isMethodFlag)
119119
: indices(indices), isMethodFlag(isMethodFlag) {}
120120

121-
public:
122-
/// Allocates and initializes an empty `AutoDiffParameterIndices` for the
123-
/// given `functionType`. `isMethod` specifies whether to treat the function
124-
/// as a method.
125-
static AutoDiffParameterIndices *
126-
create(ASTContext &C, AnyFunctionType *functionType, bool isMethod,
127-
bool setAllParams = false);
121+
static AutoDiffParameterIndices *get(llvm::SmallBitVector indices,
122+
bool isMethodFlag, ASTContext &C);
128123

124+
public:
129125
bool isMethod() const { return isMethodFlag; }
130126

131127
/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
@@ -146,40 +142,6 @@ class AutoDiffParameterIndices {
146142
/// Tests whether this set of parameters is empty.
147143
bool isEmpty() const { return indices.none(); }
148144

149-
/// Adds the indexed parameter to the set. When `isMethodFlag` is not set,
150-
/// the indices index into the first parameter list. For example,
151-
///
152-
/// functionType = (A, B, C) -> R
153-
/// paramIndex = 0
154-
/// ==> adds "A" to the set.
155-
///
156-
/// When `isMethodFlag` is set, the indices index into the first non-self
157-
/// parameter list. For example,
158-
///
159-
/// functionType = (Self) -> (A, B, C) -> R
160-
/// paramIndex = 0
161-
/// ==> adds "A" to the set.
162-
///
163-
void setNonSelfParameter(unsigned parameterIndex);
164-
165-
/// Adds all the paramaters from the first non-self parameter list to the set.
166-
/// For example,
167-
///
168-
/// functionType = (A, B, C) -> R
169-
/// ==> adds "A", B", and "C" to the set.
170-
///
171-
/// functionType = (Self) -> (A, B, C) -> R
172-
/// ==> adds "A", B", and "C" to the set.
173-
///
174-
void setAllNonSelfParameters();
175-
176-
/// Adds the self parameter to the set. `isMethodFlag` must be set. For
177-
/// example,
178-
/// functionType = (Self) -> (A, B, C) -> R
179-
/// ==> adds "Self" to the set
180-
///
181-
void setSelfParameter();
182-
183145
/// Pushes the subset's parameter's types to `paramTypes`, in the order in
184146
/// which they appear in the function type. For example,
185147
///
@@ -223,11 +185,66 @@ class AutoDiffParameterIndices {
223185
llvm::SmallBitVector getLowered(AnyFunctionType *functionType,
224186
bool selfUncurried = false) const;
225187

226-
bool operator==(const AutoDiffParameterIndices &other) const {
227-
return isMethodFlag == other.isMethodFlag && indices == other.indices;
188+
void Profile(llvm::FoldingSetNodeID &ID) const {
189+
ID.AddBoolean(isMethodFlag);
190+
ID.AddInteger(indices.size());
191+
for (unsigned setBit : indices.set_bits())
192+
ID.AddInteger(setBit);
228193
}
229194
};
230195

196+
/// Builder for `AutoDiffParameterIndices`.
197+
class AutoDiffParameterIndicesBuilder {
198+
llvm::SmallBitVector indices;
199+
bool isMethodFlag;
200+
201+
unsigned getNumNonSelfParameters() const;
202+
203+
public:
204+
/// Start building an `AutoDiffParameterIndices` for the given function type.
205+
/// `isMethod` specifies whether to treat the function as a method.
206+
AutoDiffParameterIndicesBuilder(AnyFunctionType *functionType, bool isMethod,
207+
bool setAllParams = false);
208+
209+
/// Builds the `AutoDiffParameterIndices`, returning a pointer to an existing
210+
/// one if it has already been allocated in the `ASTContext`.
211+
AutoDiffParameterIndices *build(ASTContext &C) const;
212+
213+
/// Adds the indexed parameter to the set. When `isMethodFlag` is not set,
214+
/// the indices index into the first parameter list. For example,
215+
///
216+
/// functionType = (A, B, C) -> R
217+
/// paramIndex = 0
218+
/// ==> adds "A" to the set.
219+
///
220+
/// When `isMethodFlag` is set, the indices index into the first non-self
221+
/// parameter list. For example,
222+
///
223+
/// functionType = (Self) -> (A, B, C) -> R
224+
/// paramIndex = 0
225+
/// ==> adds "A" to the set.
226+
///
227+
void setNonSelfParameter(unsigned parameterIndex);
228+
229+
/// Adds all the paramaters from the first non-self parameter list to the set.
230+
/// For example,
231+
///
232+
/// functionType = (A, B, C) -> R
233+
/// ==> adds "A", B", and "C" to the set.
234+
///
235+
/// functionType = (Self) -> (A, B, C) -> R
236+
/// ==> adds "A", B", and "C" to the set.
237+
///
238+
void setAllNonSelfParameters();
239+
240+
/// Adds the self parameter to the set. `isMethodFlag` must be set. For
241+
/// example,
242+
/// functionType = (Self) -> (A, B, C) -> R
243+
/// ==> adds "Self" to the set
244+
///
245+
void setSelfParameter();
246+
};
247+
231248
/// Differentiability of a function specifies the differentiation mode,
232249
/// parameter indices at which the function is differentiable with respect to,
233250
/// and indices of results which can be differentiated.
@@ -413,25 +430,35 @@ struct AutoDiffAssociatedFunctionKind {
413430

414431
/// In conjunction with the original function decl, identifies an associated
415432
/// autodiff function.
416-
class AutoDiffAssociatedFunctionIdentifier {
417-
AutoDiffAssociatedFunctionKind kind;
418-
unsigned differentiationOrder;
419-
AutoDiffParameterIndices *parameterIndices;
433+
///
434+
/// Is uniquely allocated within an ASTContext so that it can be hashed and
435+
/// compared by opaque pointer value.
436+
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
437+
const AutoDiffAssociatedFunctionKind kind;
438+
const unsigned differentiationOrder;
439+
AutoDiffParameterIndices * const parameterIndices;
440+
441+
AutoDiffAssociatedFunctionIdentifier(
442+
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
443+
AutoDiffParameterIndices *parameterIndices) :
444+
kind(kind), differentiationOrder(differentiationOrder),
445+
parameterIndices(parameterIndices) {}
420446

421447
public:
422448
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
423449
unsigned getDifferentiationOrder() const { return differentiationOrder; }
424-
const AutoDiffParameterIndices *getParameterIndices() const {
450+
AutoDiffParameterIndices *getParameterIndices() const {
425451
return parameterIndices;
426452
}
427453

428454
static AutoDiffAssociatedFunctionIdentifier *get(
429455
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
430456
AutoDiffParameterIndices *parameterIndices, ASTContext &C);
431457

432-
bool operator==(const AutoDiffAssociatedFunctionIdentifier &other) const {
433-
return kind == other.kind && differentiationOrder == other.differentiationOrder &&
434-
*parameterIndices == *other.parameterIndices;
458+
void Profile(llvm::FoldingSetNodeID &ID) {
459+
ID.AddInteger(kind);
460+
ID.AddInteger(differentiationOrder);
461+
ID.AddPointer(parameterIndices);
435462
}
436463
};
437464

include/swift/AST/Types.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3125,12 +3125,12 @@ class AnyFunctionType : public TypeBase {
31253125
/// resulting function will preserve all `ExtInfo` of the original function,
31263126
/// including `@autodiff`.
31273127
AnyFunctionType *getAutoDiffAssociatedFunctionType(
3128-
const AutoDiffParameterIndices &indices, unsigned resultIndex,
3128+
AutoDiffParameterIndices *indices, unsigned resultIndex,
31293129
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
31303130
LookupConformanceFn lookupConformance, bool selfUncurried = false);
31313131

31323132
AnyFunctionType *
3133-
getAutoDiffAdjointFunctionType(const AutoDiffParameterIndices &indices,
3133+
getAutoDiffAdjointFunctionType(AutoDiffParameterIndices *indices,
31343134
const TupleType *primalResultTy);
31353135

31363136
/// \brief True if this type allows an implicit conversion from a function

lib/AST/ASTContext.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,13 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
384384
/// A cache of tangent spaces per type.
385385
llvm::DenseMap<CanType, Optional<TangentSpace>> TangentSpaces;
386386

387+
/// For uniquifying `AutoDiffParameterIndices` allocations.
388+
llvm::FoldingSet<AutoDiffParameterIndices> AutoDiffParameterIndicesSet;
389+
390+
/// For uniquifying `AutoDiffAssociatedFunctionIdentifier` allocations.
391+
llvm::FoldingSet<AutoDiffAssociatedFunctionIdentifier>
392+
AutoDiffAssociatedFunctionIdentifiers;
393+
387394
/// List of Objective-C member conflicts we have found during type checking.
388395
std::vector<ObjCMethodConflict> ObjCMethodConflicts;
389396

@@ -5269,3 +5276,57 @@ Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
52695276
// support differentiation.
52705277
return cache(None);
52715278
}
5279+
5280+
AutoDiffParameterIndices *
5281+
AutoDiffParameterIndices::get(llvm::SmallBitVector indices, bool isMethodFlag,
5282+
ASTContext &C) {
5283+
auto &foldingSet = C.getImpl().AutoDiffParameterIndicesSet;
5284+
5285+
llvm::FoldingSetNodeID id;
5286+
id.AddBoolean(isMethodFlag);
5287+
id.AddInteger(indices.size());
5288+
for (unsigned setBit : indices.set_bits())
5289+
id.AddInteger(setBit);
5290+
5291+
void *insertPos;
5292+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
5293+
if (existing)
5294+
return existing;
5295+
5296+
// TODO(SR-9290): Note that the AutoDiffParameterIndices' destructor never
5297+
// gets called, which causes a small memory leak in the case that the
5298+
// SmallBitVector decides to allocate some heap space.
5299+
void *mem = C.Allocate(sizeof(AutoDiffParameterIndices),
5300+
alignof(AutoDiffParameterIndices));
5301+
auto *newNode = ::new (mem) AutoDiffParameterIndices(indices, isMethodFlag);
5302+
foldingSet.InsertNode(newNode, insertPos);
5303+
5304+
return newNode;
5305+
}
5306+
5307+
AutoDiffAssociatedFunctionIdentifier *
5308+
AutoDiffAssociatedFunctionIdentifier::get(
5309+
AutoDiffAssociatedFunctionKind kind,
5310+
unsigned differentiationOrder,
5311+
AutoDiffParameterIndices *parameterIndices,
5312+
ASTContext &C) {
5313+
auto &foldingSet = C.getImpl().AutoDiffAssociatedFunctionIdentifiers;
5314+
5315+
llvm::FoldingSetNodeID id;
5316+
id.AddInteger((unsigned)kind);
5317+
id.AddInteger(differentiationOrder);
5318+
id.AddPointer(parameterIndices);
5319+
5320+
void *insertPos;
5321+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
5322+
if (existing)
5323+
return existing;
5324+
5325+
void *mem = C.Allocate(sizeof(AutoDiffAssociatedFunctionIdentifier),
5326+
alignof(AutoDiffAssociatedFunctionIdentifier));
5327+
auto *newNode = ::new (mem) AutoDiffAssociatedFunctionIdentifier(
5328+
kind, differentiationOrder, parameterIndices);
5329+
foldingSet.InsertNode(newNode, insertPos);
5330+
5331+
return newNode;
5332+
}

0 commit comments

Comments
 (0)