Skip to content

Commit a0a7ee5

Browse files
authored
uniquely allocate AutoDiffParameterIndices and AutoDiffAssociatedFunctionIdentifier (#21193)
Resolves https://bugs.swift.org/browse/SR-9428.
1 parent 3267cbb commit a0a7ee5

18 files changed

+543
-420
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 97 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -79,107 +79,60 @@ class AutoDiffParameter {
7979
};
8080

8181
class AnyFunctionType;
82+
class AutoDiffParameterIndicesBuilder;
8283
class Type;
8384

8485
/// Identifies a subset of a function's parameters.
8586
///
8687
/// Works with AST-level function decls and types. Requires further lowering to
8788
/// work with SIL-level functions and types. (In particular, tuples must be
8889
/// exploded).
89-
class AutoDiffParameterIndices {
90+
///
91+
/// Is uniquely allocated within an ASTContext so that it can be hashed and
92+
/// compared by opaque pointer value.
93+
class AutoDiffParameterIndices : public llvm::FoldingSetNode {
94+
friend AutoDiffParameterIndicesBuilder;
95+
9096
/// Bits corresponding to parameters in the set are "on", and bits
9197
/// corresponding to parameters not in the set are "off".
9298
///
93-
/// Normally, the bits correspond to the function's parameters in order. For
94-
/// example,
99+
/// For non-method functions, the bits correspond to the function's
100+
//// parameters in order. For example,
95101
///
96102
/// Function type: (A, B, C) -> R
97103
/// Bits: [A][B][C]
98104
///
99-
/// When `isMethodFlag` is set, the bits correspond to the function's
100-
/// non-self parameters in order, followed by the function's self parameter.
101-
/// For example,
105+
/// For methods, the bits correspond to the function's non-self parameters
106+
/// in order, followed by the function's self parameter. For example,
102107
///
103108
/// Function type: (Self) -> (A, B, C) -> R
104109
/// Bits: [A][B][C][Self]
105110
///
106-
llvm::SmallBitVector indices;
111+
const llvm::SmallBitVector indices;
107112

108-
/// Whether the function is a method.
109-
///
110-
bool isMethodFlag;
113+
AutoDiffParameterIndices(llvm::SmallBitVector indices)
114+
: indices(indices) {}
111115

112-
unsigned getNumNonSelfParameters() const;
113-
114-
AutoDiffParameterIndices(unsigned numIndices, bool isMethodFlag,
115-
bool setAllParams = false)
116-
: indices(numIndices, setAllParams), isMethodFlag(isMethodFlag) {}
117-
118-
AutoDiffParameterIndices(llvm::SmallBitVector indices, bool isMethodFlag)
119-
: indices(indices), isMethodFlag(isMethodFlag) {}
116+
static AutoDiffParameterIndices *get(llvm::SmallBitVector indices,
117+
ASTContext &C);
120118

121119
public:
122-
/// Allocates and initializes an empty `AutoDiffParameterIndices` for the
123-
/// given `functionType`. `isMethod` specifies whether to treat the function
124-
/// as a method.
125-
static AutoDiffParameterIndices *
126-
create(ASTContext &C, AnyFunctionType *functionType, bool isMethod,
127-
bool setAllParams = false);
128-
129-
bool isMethod() const { return isMethodFlag; }
130-
131120
/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
132121
/// the given `string` generated by `getString()`. If the string is invalid,
133122
/// returns nullptr.
134123
static AutoDiffParameterIndices *create(ASTContext &C, StringRef string);
135124

136125
/// Returns a textual string description of these indices,
137126
///
138-
/// [FM][SU]+
127+
/// [SU]+
139128
///
140-
/// "F" means that `isMethodFlag` is false
141-
/// "M" means that `isMethodFlag` is true
142129
/// "S" means that the corresponding index is set
143130
/// "U" means that the corresponding index is unset
144131
std::string getString() const;
145132

146133
/// Tests whether this set of parameters is empty.
147134
bool isEmpty() const { return indices.none(); }
148135

149-
/// Adds the indexed parameter to the set. When `isMethodFlag` is not set,
150-
/// the indices index into the first parameter list. For example,
151-
///
152-
/// functionType = (A, B, C) -> R
153-
/// paramIndex = 0
154-
/// ==> adds "A" to the set.
155-
///
156-
/// When `isMethodFlag` is set, the indices index into the first non-self
157-
/// parameter list. For example,
158-
///
159-
/// functionType = (Self) -> (A, B, C) -> R
160-
/// paramIndex = 0
161-
/// ==> adds "A" to the set.
162-
///
163-
void setNonSelfParameter(unsigned parameterIndex);
164-
165-
/// Adds all the paramaters from the first non-self parameter list to the set.
166-
/// For example,
167-
///
168-
/// functionType = (A, B, C) -> R
169-
/// ==> adds "A", B", and "C" to the set.
170-
///
171-
/// functionType = (Self) -> (A, B, C) -> R
172-
/// ==> adds "A", B", and "C" to the set.
173-
///
174-
void setAllNonSelfParameters();
175-
176-
/// Adds the self parameter to the set. `isMethodFlag` must be set. For
177-
/// example,
178-
/// functionType = (Self) -> (A, B, C) -> R
179-
/// ==> adds "Self" to the set
180-
///
181-
void setSelfParameter();
182-
183136
/// Pushes the subset's parameter's types to `paramTypes`, in the order in
184137
/// which they appear in the function type. For example,
185138
///
@@ -191,11 +144,14 @@ class AutoDiffParameterIndices {
191144
/// if "Self" and "C" are in the set,
192145
/// ==> pushes {Self, C} to `paramTypes`.
193146
///
147+
/// Pass `isMethod = true` when the function is a method.
148+
///
194149
/// Pass `selfUncurried = true` when the function type is for a method whose
195150
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
196151
///
197152
void getSubsetParameterTypes(AnyFunctionType *functionType,
198153
SmallVectorImpl<Type> &paramTypes,
154+
bool isMethod,
199155
bool selfUncurried = false) const;
200156

201157
/// Returns a bitvector for the SILFunction parameters corresponding to the
@@ -217,17 +173,74 @@ class AutoDiffParameterIndices {
217173
/// ==> returns 1110
218174
/// (because the lowered SIL type is (A, B, C, D) -> R)
219175
///
176+
/// Pass `isMethod = true` when the function is a method.
177+
///
220178
/// Pass `selfUncurried = true` when the function type is for a method whose
221179
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
222180
///
223181
llvm::SmallBitVector getLowered(AnyFunctionType *functionType,
182+
bool isMethod,
224183
bool selfUncurried = false) const;
225184

226-
bool operator==(const AutoDiffParameterIndices &other) const {
227-
return isMethodFlag == other.isMethodFlag && indices == other.indices;
185+
void Profile(llvm::FoldingSetNodeID &ID) const {
186+
ID.AddInteger(indices.size());
187+
for (unsigned setBit : indices.set_bits())
188+
ID.AddInteger(setBit);
228189
}
229190
};
230191

192+
/// Builder for `AutoDiffParameterIndices`.
193+
class AutoDiffParameterIndicesBuilder {
194+
llvm::SmallBitVector indices;
195+
bool isMethod;
196+
197+
unsigned getNumNonSelfParameters() const;
198+
199+
public:
200+
/// Start building an `AutoDiffParameterIndices` for the given function type.
201+
/// `isMethod` specifies whether to treat the function as a method.
202+
AutoDiffParameterIndicesBuilder(AnyFunctionType *functionType, bool isMethod,
203+
bool setAllParams = false);
204+
205+
/// Builds the `AutoDiffParameterIndices`, returning a pointer to an existing
206+
/// one if it has already been allocated in the `ASTContext`.
207+
AutoDiffParameterIndices *build(ASTContext &C) const;
208+
209+
/// Adds the indexed parameter to the set. When `isMethod` is not set,
210+
/// the indices index into the first parameter list. For example,
211+
///
212+
/// functionType = (A, B, C) -> R
213+
/// paramIndex = 0
214+
/// ==> adds "A" to the set.
215+
///
216+
/// When `isMethod` is set, the indices index into the first non-self
217+
/// parameter list. For example,
218+
///
219+
/// functionType = (Self) -> (A, B, C) -> R
220+
/// paramIndex = 0
221+
/// ==> adds "A" to the set.
222+
///
223+
void setNonSelfParameter(unsigned parameterIndex);
224+
225+
/// Adds all the paramaters from the first non-self parameter list to the set.
226+
/// For example,
227+
///
228+
/// functionType = (A, B, C) -> R
229+
/// ==> adds "A", B", and "C" to the set.
230+
///
231+
/// functionType = (Self) -> (A, B, C) -> R
232+
/// ==> adds "A", B", and "C" to the set.
233+
///
234+
void setAllNonSelfParameters();
235+
236+
/// Adds the self parameter to the set. `isMethod` must be set. For
237+
/// example,
238+
/// functionType = (Self) -> (A, B, C) -> R
239+
/// ==> adds "Self" to the set
240+
///
241+
void setSelfParameter();
242+
};
243+
231244
/// Differentiability of a function specifies the differentiation mode,
232245
/// parameter indices at which the function is differentiable with respect to,
233246
/// and indices of results which can be differentiated.
@@ -413,25 +426,35 @@ struct AutoDiffAssociatedFunctionKind {
413426

414427
/// In conjunction with the original function decl, identifies an associated
415428
/// autodiff function.
416-
class AutoDiffAssociatedFunctionIdentifier {
417-
AutoDiffAssociatedFunctionKind kind;
418-
unsigned differentiationOrder;
419-
AutoDiffParameterIndices *parameterIndices;
429+
///
430+
/// Is uniquely allocated within an ASTContext so that it can be hashed and
431+
/// compared by opaque pointer value.
432+
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
433+
const AutoDiffAssociatedFunctionKind kind;
434+
const unsigned differentiationOrder;
435+
AutoDiffParameterIndices * const parameterIndices;
436+
437+
AutoDiffAssociatedFunctionIdentifier(
438+
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
439+
AutoDiffParameterIndices *parameterIndices) :
440+
kind(kind), differentiationOrder(differentiationOrder),
441+
parameterIndices(parameterIndices) {}
420442

421443
public:
422444
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
423445
unsigned getDifferentiationOrder() const { return differentiationOrder; }
424-
const AutoDiffParameterIndices *getParameterIndices() const {
446+
AutoDiffParameterIndices *getParameterIndices() const {
425447
return parameterIndices;
426448
}
427449

428450
static AutoDiffAssociatedFunctionIdentifier *get(
429451
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
430452
AutoDiffParameterIndices *parameterIndices, ASTContext &C);
431453

432-
bool operator==(const AutoDiffAssociatedFunctionIdentifier &other) const {
433-
return kind == other.kind && differentiationOrder == other.differentiationOrder &&
434-
*parameterIndices == *other.parameterIndices;
454+
void Profile(llvm::FoldingSetNodeID &ID) {
455+
ID.AddInteger(kind);
456+
ID.AddInteger(differentiationOrder);
457+
ID.AddPointer(parameterIndices);
435458
}
436459
};
437460

include/swift/AST/Types.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3125,13 +3125,15 @@ class AnyFunctionType : public TypeBase {
31253125
/// resulting function will preserve all `ExtInfo` of the original function,
31263126
/// including `@autodiff`.
31273127
AnyFunctionType *getAutoDiffAssociatedFunctionType(
3128-
const AutoDiffParameterIndices &indices, unsigned resultIndex,
3128+
AutoDiffParameterIndices *indices, unsigned resultIndex,
31293129
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
3130-
LookupConformanceFn lookupConformance, bool selfUncurried = false);
3130+
LookupConformanceFn lookupConformance, bool isMethod,
3131+
bool selfUncurried = false);
31313132

31323133
AnyFunctionType *
3133-
getAutoDiffAdjointFunctionType(const AutoDiffParameterIndices &indices,
3134-
const TupleType *primalResultTy);
3134+
getAutoDiffAdjointFunctionType(AutoDiffParameterIndices *indices,
3135+
const TupleType *primalResultTy,
3136+
bool isMethod);
31353137

31363138
/// \brief True if this type allows an implicit conversion from a function
31373139
/// argument expression of type T to a function of type () -> T.

lib/AST/ASTContext.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,13 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
384384
/// A cache of tangent spaces per type.
385385
llvm::DenseMap<CanType, Optional<TangentSpace>> TangentSpaces;
386386

387+
/// For uniquifying `AutoDiffParameterIndices` allocations.
388+
llvm::FoldingSet<AutoDiffParameterIndices> AutoDiffParameterIndicesSet;
389+
390+
/// For uniquifying `AutoDiffAssociatedFunctionIdentifier` allocations.
391+
llvm::FoldingSet<AutoDiffAssociatedFunctionIdentifier>
392+
AutoDiffAssociatedFunctionIdentifiers;
393+
387394
/// List of Objective-C member conflicts we have found during type checking.
388395
std::vector<ObjCMethodConflict> ObjCMethodConflicts;
389396

@@ -5269,3 +5276,53 @@ Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
52695276
// support differentiation.
52705277
return cache(None);
52715278
}
5279+
5280+
AutoDiffParameterIndices *
5281+
AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) {
5282+
auto &foldingSet = C.getImpl().AutoDiffParameterIndicesSet;
5283+
5284+
llvm::FoldingSetNodeID id;
5285+
id.AddInteger(indices.size());
5286+
for (unsigned setBit : indices.set_bits())
5287+
id.AddInteger(setBit);
5288+
5289+
void *insertPos;
5290+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
5291+
if (existing)
5292+
return existing;
5293+
5294+
// TODO(SR-9290): Note that the AutoDiffParameterIndices' destructor never
5295+
// gets called, which causes a small memory leak in the case that the
5296+
// SmallBitVector decides to allocate some heap space.
5297+
void *mem = C.Allocate(sizeof(AutoDiffParameterIndices),
5298+
alignof(AutoDiffParameterIndices));
5299+
auto *newNode = ::new (mem) AutoDiffParameterIndices(indices);
5300+
foldingSet.InsertNode(newNode, insertPos);
5301+
5302+
return newNode;
5303+
}
5304+
5305+
AutoDiffAssociatedFunctionIdentifier *
5306+
AutoDiffAssociatedFunctionIdentifier::get(
5307+
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
5308+
AutoDiffParameterIndices *parameterIndices, ASTContext &C) {
5309+
auto &foldingSet = C.getImpl().AutoDiffAssociatedFunctionIdentifiers;
5310+
5311+
llvm::FoldingSetNodeID id;
5312+
id.AddInteger((unsigned)kind);
5313+
id.AddInteger(differentiationOrder);
5314+
id.AddPointer(parameterIndices);
5315+
5316+
void *insertPos;
5317+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
5318+
if (existing)
5319+
return existing;
5320+
5321+
void *mem = C.Allocate(sizeof(AutoDiffAssociatedFunctionIdentifier),
5322+
alignof(AutoDiffAssociatedFunctionIdentifier));
5323+
auto *newNode = ::new (mem) AutoDiffAssociatedFunctionIdentifier(
5324+
kind, differentiationOrder, parameterIndices);
5325+
foldingSet.InsertNode(newNode, insertPos);
5326+
5327+
return newNode;
5328+
}

0 commit comments

Comments
 (0)