Skip to content

Commit 3b8498d

Browse files
committed
Pick up changes from IndexSubset.
1 parent 46ff5f2 commit 3b8498d

File tree

4 files changed

+10
-368
lines changed

4 files changed

+10
-368
lines changed

include/swift/AST/Attr.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,7 +1758,7 @@ class DifferentiableAttr final
17581758
/// specified.
17591759
FuncDecl *VJPFunction = nullptr;
17601760
/// The differentiation parameters' indices, resolved by the type checker.
1761-
AutoDiffIndexSubset *ParameterIndices = nullptr;
1761+
IndexSubset *ParameterIndices = nullptr;
17621762
/// The trailing where clause (optional).
17631763
TrailingWhereClause *WhereClause = nullptr;
17641764
/// The generic signature for autodiff associated functions. Resolved by the
@@ -1777,7 +1777,7 @@ class DifferentiableAttr final
17771777

17781778
explicit DifferentiableAttr(ASTContext &context, bool implicit,
17791779
SourceLoc atLoc, SourceRange baseRange,
1780-
bool linear, AutoDiffIndexSubset *indices,
1780+
bool linear, IndexSubset *indices,
17811781
Optional<DeclNameWithLoc> jvp,
17821782
Optional<DeclNameWithLoc> vjp,
17831783
GenericSignature derivativeGenericSignature);
@@ -1793,7 +1793,7 @@ class DifferentiableAttr final
17931793

17941794
static DifferentiableAttr *create(ASTContext &context, bool implicit,
17951795
SourceLoc atLoc, SourceRange baseRange,
1796-
bool linear, AutoDiffIndexSubset *indices,
1796+
bool linear, IndexSubset *indices,
17971797
Optional<DeclNameWithLoc> jvp,
17981798
Optional<DeclNameWithLoc> vjp,
17991799
GenericSignature derivativeGenSig);
@@ -1808,10 +1808,10 @@ class DifferentiableAttr final
18081808
/// registered VJP.
18091809
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
18101810

1811-
AutoDiffIndexSubset *getParameterIndices() const {
1811+
IndexSubset *getParameterIndices() const {
18121812
return ParameterIndices;
18131813
}
1814-
void setParameterIndices(AutoDiffIndexSubset *pi) {
1814+
void setParameterIndices(IndexSubset *pi) {
18151815
ParameterIndices = pi;
18161816
}
18171817

include/swift/AST/AutoDiff.h

Lines changed: 0 additions & 241 deletions
Original file line numberDiff line numberDiff line change
@@ -87,247 +87,6 @@ class ParsedAutoDiffParameter {
8787
}
8888
};
8989

90-
/// An efficient index subset data structure, uniqued in `ASTContext`.
91-
/// Stores a bit vector representing set indices and a total capacity.
92-
class AutoDiffIndexSubset : public llvm::FoldingSetNode {
93-
public:
94-
typedef uint64_t BitWord;
95-
96-
static constexpr unsigned bitWordSize = sizeof(BitWord);
97-
static constexpr unsigned numBitsPerBitWord = bitWordSize * 8;
98-
99-
static std::pair<unsigned, unsigned>
100-
getBitWordIndexAndOffset(unsigned index) {
101-
auto bitWordIndex = index / numBitsPerBitWord;
102-
auto bitWordOffset = index % numBitsPerBitWord;
103-
return {bitWordIndex, bitWordOffset};
104-
}
105-
106-
static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) {
107-
if (capacity == 0) return 0;
108-
return capacity / numBitsPerBitWord + 1;
109-
}
110-
111-
private:
112-
/// The total capacity of the index subset, which is `1` less than the largest
113-
/// index.
114-
unsigned capacity;
115-
/// The number of bit words in the index subset.
116-
unsigned numBitWords;
117-
118-
BitWord *getBitWordsData() {
119-
return reinterpret_cast<BitWord *>(this + 1);
120-
}
121-
122-
const BitWord *getBitWordsData() const {
123-
return reinterpret_cast<const BitWord *>(this + 1);
124-
}
125-
126-
ArrayRef<BitWord> getBitWords() const {
127-
return {getBitWordsData(), getNumBitWords()};
128-
}
129-
130-
BitWord getBitWord(unsigned i) const {
131-
return getBitWordsData()[i];
132-
}
133-
134-
BitWord &getBitWord(unsigned i) {
135-
return getBitWordsData()[i];
136-
}
137-
138-
MutableArrayRef<BitWord> getMutableBitWords() {
139-
return {const_cast<BitWord *>(getBitWordsData()), getNumBitWords()};
140-
}
141-
142-
explicit AutoDiffIndexSubset(const SmallBitVector &indices)
143-
: capacity((unsigned)indices.size()),
144-
numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
145-
std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0);
146-
for (auto i : indices.set_bits()) {
147-
unsigned bitWordIndex, offset;
148-
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i);
149-
getBitWord(bitWordIndex) |= (1 << offset);
150-
}
151-
}
152-
153-
public:
154-
AutoDiffIndexSubset() = delete;
155-
AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete;
156-
AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete;
157-
158-
// Defined in ASTContext.cpp.
159-
static AutoDiffIndexSubset *get(ASTContext &ctx,
160-
const SmallBitVector &indices);
161-
162-
static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity,
163-
ArrayRef<unsigned> indices) {
164-
SmallBitVector indicesBitVec(capacity, false);
165-
for (auto index : indices)
166-
indicesBitVec.set(index);
167-
return AutoDiffIndexSubset::get(ctx, indicesBitVec);
168-
}
169-
170-
static AutoDiffIndexSubset *getDefault(ASTContext &ctx, unsigned capacity,
171-
bool includeAll = false) {
172-
return get(ctx, SmallBitVector(capacity, includeAll));
173-
}
174-
175-
static AutoDiffIndexSubset *getFromRange(ASTContext &ctx, unsigned capacity,
176-
unsigned start, unsigned end) {
177-
assert(start < capacity);
178-
assert(end <= capacity);
179-
SmallBitVector bitVec(capacity);
180-
bitVec.set(start, end);
181-
return get(ctx, bitVec);
182-
}
183-
184-
/// Creates an index subset corresponding to the given string generated by
185-
/// `getString()`. If the string is invalid, returns nullptr.
186-
static AutoDiffIndexSubset *getFromString(ASTContext &ctx, StringRef string);
187-
188-
/// Returns the number of bit words used to store the index subset.
189-
// Note: Use `getCapacity()` to get the total index subset capacity.
190-
// This is public only for unit testing
191-
// (in unittests/AST/SILAutoDiffIndices.cpp).
192-
unsigned getNumBitWords() const {
193-
return numBitWords;
194-
}
195-
196-
/// Returns the capacity of the index subset.
197-
unsigned getCapacity() const {
198-
return capacity;
199-
}
200-
201-
/// Returns a textual string description of these indices.
202-
///
203-
/// It has the format `[SU]+`, where the total number of characters is equal
204-
/// to the capacity, and where "S" means that the corresponding index is
205-
/// contained and "U" means that the corresponding index is not.
206-
std::string getString() const;
207-
208-
class iterator;
209-
210-
iterator begin() const {
211-
return iterator(this);
212-
}
213-
214-
iterator end() const {
215-
return iterator(this, (int)capacity);
216-
}
217-
218-
iterator_range<iterator> getIndices() const {
219-
return make_range(begin(), end());
220-
}
221-
222-
unsigned getNumIndices() const {
223-
return (unsigned)std::distance(begin(), end());
224-
}
225-
226-
SmallBitVector getBitVector() const {
227-
SmallBitVector indicesBitVec(capacity, false);
228-
for (auto index : getIndices())
229-
indicesBitVec.set(index);
230-
return indicesBitVec;
231-
}
232-
233-
bool contains(unsigned index) const {
234-
unsigned bitWordIndex, offset;
235-
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index);
236-
return getBitWord(bitWordIndex) & (1 << offset);
237-
}
238-
239-
bool isEmpty() const {
240-
return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
241-
}
242-
243-
bool equals(AutoDiffIndexSubset *other) const {
244-
return capacity == other->getCapacity() &&
245-
getBitWords().equals(other->getBitWords());
246-
}
247-
248-
bool isSubsetOf(AutoDiffIndexSubset *other) const;
249-
bool isSupersetOf(AutoDiffIndexSubset *other) const;
250-
251-
AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const;
252-
AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx,
253-
unsigned newCapacity) const;
254-
255-
void Profile(llvm::FoldingSetNodeID &id) const {
256-
id.AddInteger(capacity);
257-
for (auto index : getIndices())
258-
id.AddInteger(index);
259-
}
260-
261-
void print(llvm::raw_ostream &s = llvm::outs()) const {
262-
s << '{';
263-
interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); },
264-
[&s] { s << ", "; });
265-
s << '}';
266-
}
267-
268-
void dump(llvm::raw_ostream &s = llvm::errs()) const {
269-
s << "(autodiff_index_subset capacity=" << capacity << " indices=(";
270-
interleave(getIndices(), [&s](unsigned i) { s << i; },
271-
[&s] { s << ", "; });
272-
s << "))";
273-
}
274-
275-
int findNext(int startIndex) const;
276-
int findFirst() const { return findNext(-1); }
277-
int findPrevious(int endIndex) const;
278-
int findLast() const { return findPrevious(capacity); }
279-
280-
class iterator {
281-
public:
282-
typedef unsigned value_type;
283-
typedef unsigned difference_type;
284-
typedef unsigned * pointer;
285-
typedef unsigned & reference;
286-
typedef std::forward_iterator_tag iterator_category;
287-
288-
private:
289-
const AutoDiffIndexSubset *parent;
290-
int current = 0;
291-
292-
void advance() {
293-
assert(current != -1 && "Trying to advance past end.");
294-
current = parent->findNext(current);
295-
}
296-
297-
public:
298-
iterator(const AutoDiffIndexSubset *parent, int current)
299-
: parent(parent), current(current) {}
300-
explicit iterator(const AutoDiffIndexSubset *parent)
301-
: iterator(parent, parent->findFirst()) {}
302-
iterator(const iterator &) = default;
303-
304-
iterator operator++(int) {
305-
auto prev = *this;
306-
advance();
307-
return prev;
308-
}
309-
310-
iterator &operator++() {
311-
advance();
312-
return *this;
313-
}
314-
315-
unsigned operator*() const { return current; }
316-
317-
bool operator==(const iterator &other) const {
318-
assert(parent == other.parent &&
319-
"Comparing iterators from different AutoDiffIndexSubsets");
320-
return current == other.current;
321-
}
322-
323-
bool operator!=(const iterator &other) const {
324-
assert(parent == other.parent &&
325-
"Comparing iterators from different AutoDiffIndexSubsets");
326-
return current != other.current;
327-
}
328-
};
329-
};
330-
33190
} // end namespace swift
33291

33392
#endif // SWIFT_AST_AUTODIFF_H

lib/AST/Attr.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "swift/AST/Decl.h"
2121
#include "swift/AST/Expr.h"
2222
#include "swift/AST/GenericEnvironment.h"
23+
#include "swift/AST/IndexSubset.h"
2324
#include "swift/AST/Module.h"
2425
#include "swift/AST/TypeRepr.h"
2526
#include "swift/AST/Types.h"
@@ -353,13 +354,13 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
353354
}
354355

355356
static std::string getDifferentiationParametersClauseString(
356-
const AbstractFunctionDecl *function, AutoDiffIndexSubset *indices,
357+
const AbstractFunctionDecl *function, IndexSubset *indices,
357358
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
358359
bool isInstanceMethod = function && function->isInstanceMember();
359360
std::string result;
360361
llvm::raw_string_ostream printer(result);
361362

362-
// Use parameter indices from `AutoDiffIndexSubset`, if specified.
363+
// Use parameter indices from `IndexSubset`, if specified.
363364
if (indices) {
364365
auto parameters = indices->getBitVector();
365366
auto parameterCount = parameters.count();
@@ -1340,7 +1341,7 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
13401341
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
13411342
SourceLoc atLoc, SourceRange baseRange,
13421343
bool linear,
1343-
AutoDiffIndexSubset *indices,
1344+
IndexSubset *indices,
13441345
Optional<DeclNameWithLoc> jvp,
13451346
Optional<DeclNameWithLoc> vjp,
13461347
GenericSignature derivativeGenSig)
@@ -1368,7 +1369,7 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
13681369
DifferentiableAttr *
13691370
DifferentiableAttr::create(ASTContext &context, bool implicit,
13701371
SourceLoc atLoc, SourceRange baseRange,
1371-
bool linear, AutoDiffIndexSubset *indices,
1372+
bool linear, IndexSubset *indices,
13721373
Optional<DeclNameWithLoc> jvp,
13731374
Optional<DeclNameWithLoc> vjp,
13741375
GenericSignature derivativeGenSig) {

0 commit comments

Comments
 (0)