Skip to content

Commit 4f08e57

Browse files
committed
[AutoDiff] Parameter indices data structure overhaul.
1 parent 81c2f84 commit 4f08e57

File tree

11 files changed

+352
-330
lines changed

11 files changed

+352
-330
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 143 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -223,22 +223,31 @@ class AutoDiffParameterIndicesBuilder {
223223
};
224224

225225
class AutoDiffIndexSubset : public llvm::FoldingSetNode {
226-
private:
227-
using BitWord = uint64_t;
226+
public:
227+
typedef uint64_t BitWord;
228+
229+
static constexpr unsigned bitWordSize = sizeof(BitWord);
230+
static constexpr unsigned numBitsPerBitWord = bitWordSize * 8;
231+
232+
static std::pair<unsigned, unsigned>
233+
getBitWordIndexAndOffset(unsigned index) {
234+
auto bitWordIndex = index / numBitsPerBitWord;
235+
auto bitWordOffset = index % numBitsPerBitWord;
236+
return {bitWordIndex, bitWordOffset};
237+
}
228238

239+
static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) {
240+
if (capacity == 0) return 0;
241+
return capacity / numBitsPerBitWord + 1;
242+
}
243+
244+
private:
229245
/// The total capacity of the index subset, which is `1` less than the largest
230246
/// index.
231247
unsigned capacity;
232248
/// The number of bit words in the index subset. in the index subset.
233249
unsigned numBitWords;
234250

235-
static std::pair<unsigned, unsigned> getBitWordIndexAndOffset(unsigned index);
236-
static unsigned getNumBitWordsNeededForCapacity(unsigned capacity);
237-
238-
unsigned getNumBitWords() const {
239-
return numBitWords;
240-
}
241-
242251
BitWord *getBitWordsData() {
243252
return reinterpret_cast<BitWord *>(this + 1);
244253
}
@@ -263,62 +272,159 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
263272
return {const_cast<BitWord *>(getBitWordsData()), getNumBitWords()};
264273
}
265274

266-
explicit AutoDiffIndexSubset(unsigned capacity, unsigned numBitWords,
267-
ArrayRef<unsigned> indices);
275+
explicit AutoDiffIndexSubset(unsigned capacity, ArrayRef<unsigned> indices)
276+
: capacity(capacity),
277+
numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
278+
std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0);
279+
for (auto i : indices) {
280+
unsigned bitWordIndex, offset;
281+
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i);
282+
getBitWord(bitWordIndex) |= (1 << offset);
283+
}
284+
}
268285

269286
public:
270287
AutoDiffIndexSubset() = delete;
271288
AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete;
272289
AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete;
273290

274-
static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity,
275-
bool includeAll = false);
276-
static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity,
277-
IntRange<> range);
278-
static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity,
291+
// Defined in ASTContext.h.
292+
static AutoDiffIndexSubset *get(ASTContext &ctx,
293+
unsigned capacity,
279294
ArrayRef<unsigned> indices);
280295

296+
static AutoDiffIndexSubset *getDefault(ASTContext &ctx,
297+
unsigned capacity,
298+
bool includeAll = false) {
299+
if (includeAll)
300+
return getFromRange(ctx, capacity, IntRange<>(capacity));
301+
return get(ctx, capacity, {});
302+
}
303+
304+
static AutoDiffIndexSubset *getFromRange(ASTContext &ctx,
305+
unsigned capacity,
306+
IntRange<> range) {
307+
return get(ctx, capacity,
308+
SmallVector<unsigned, 8>(range.begin(), range.end()));
309+
}
310+
311+
unsigned getNumBitWords() const {
312+
return numBitWords;
313+
}
314+
281315
unsigned getCapacity() const {
282316
return capacity;
283317
}
284318

285319
class iterator;
286320

287-
iterator begin() const;
288-
iterator end() const;
289-
iterator_range<iterator> getIndices() const;
290-
unsigned getNumIndices() const;
321+
iterator begin() const {
322+
return iterator(this);
323+
}
324+
325+
iterator end() const {
326+
return iterator(this, (int)capacity);
327+
}
328+
329+
iterator_range<iterator> getIndices() const {
330+
return make_range(begin(), end());
331+
}
332+
333+
unsigned getNumIndices() const {
334+
return (unsigned)std::distance(begin(), end());
335+
}
291336

292337
bool contains(unsigned index) const {
293338
unsigned bitWordIndex, offset;
294339
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index);
295340
return getBitWord(bitWordIndex) & (1 << offset);
296341
}
297342

298-
bool isEmpty() const;
299-
bool equals(const AutoDiffIndexSubset *other) const;
300-
bool isSubsetOf(const AutoDiffIndexSubset *other) const;
301-
bool isSupersetOf(const AutoDiffIndexSubset *other) const;
343+
bool isEmpty() const {
344+
return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
345+
}
346+
347+
bool equals(const AutoDiffIndexSubset *other) const {
348+
return capacity == other->getCapacity() &&
349+
getBitWords().equals(other->getBitWords());
350+
}
351+
352+
bool isSubsetOf(const AutoDiffIndexSubset *other) const {
353+
assert(capacity == other->capacity);
354+
for (auto index : range(numBitWords))
355+
if (getBitWord(index) & ~other->getBitWord(index))
356+
return false;
357+
return true;
358+
}
359+
360+
bool isSupersetOf(const AutoDiffIndexSubset *other) const {
361+
assert(capacity == other->capacity);
362+
for (auto index : range(numBitWords))
363+
if (~getBitWord(index) & other->getBitWord(index))
364+
return false;
365+
return true;
366+
}
302367

303-
AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const;
304-
AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx,
305-
unsigned newCapacity) const;
368+
AutoDiffIndexSubset *adding(
369+
unsigned index, ASTContext &ctx) const {
370+
assert(index < getCapacity());
371+
SmallVector<unsigned, 8> newIndices;
372+
newIndices.reserve(capacity + 1);
373+
bool inserted = false;
374+
for (auto curIndex : getIndices()) {
375+
if (inserted && curIndex > index) {
376+
newIndices.push_back(index);
377+
inserted = false;
378+
}
379+
newIndices.push_back(curIndex);
380+
}
381+
return get(ctx, capacity, newIndices);
382+
}
306383

307-
void Profile(llvm::FoldingSetNodeID &id) const;
384+
AutoDiffIndexSubset *extendingCapacity(
385+
ASTContext &ctx, unsigned newCapacity) const {
386+
assert(newCapacity >= capacity);
387+
if (newCapacity == capacity)
388+
return const_cast<AutoDiffIndexSubset *>(this);
389+
SmallVector<unsigned, 8> indices;
390+
for (auto index : getIndices())
391+
indices.push_back(index);
392+
return AutoDiffIndexSubset::get(ctx, newCapacity, indices);
393+
}
394+
395+
void Profile(llvm::FoldingSetNodeID &id) const {
396+
id.AddInteger(capacity);
397+
for (auto index : getIndices())
398+
id.AddInteger(index);
399+
}
400+
401+
void print(llvm::raw_ostream &s = llvm::outs()) const {
402+
s << '{';
403+
interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); },
404+
[&s] { s << ", "; });
405+
s << '}';
406+
}
407+
408+
void dump(llvm::raw_ostream &s = llvm::errs()) const {
409+
s << "(autodiff_index_subset capacity=" << capacity << " indices=(";
410+
interleave(getIndices(), [&s](unsigned i) { s << i; },
411+
[&s] { s << ", "; });
412+
s << "))";
413+
}
308414

309-
private:
310415
int findNext(int startIndex) const;
311416
int findFirst() const { return findNext(-1); }
312417
int findPrevious(int endIndex) const;
313418
int findLast() const { return findPrevious(capacity); }
314419

315-
public:
316420
class iterator {
317-
typedef unsigned value_type;
318-
typedef int difference_type;
319-
typedef unsigned * pointer;
320-
typedef unsigned & reference;
321-
typedef std::forward_iterator_tag iterator_category;
421+
public:
422+
typedef unsigned value_type;
423+
typedef unsigned difference_type;
424+
typedef unsigned * pointer;
425+
typedef unsigned & reference;
426+
typedef std::forward_iterator_tag iterator_category;
427+
322428
private:
323429
const AutoDiffIndexSubset *parent;
324430
int current = 0;
@@ -349,42 +455,19 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
349455
unsigned operator*() const { return current; }
350456

351457
bool operator==(const iterator &other) const {
352-
assert(&parent == &other.parent &&
458+
assert(parent == other.parent &&
353459
"Comparing iterators from different AutoDiffIndexSubsets");
354460
return current == other.current;
355461
}
356462

357463
bool operator!=(const iterator &other) const {
358-
assert(&parent == &other.parent &&
464+
assert(parent == other.parent &&
359465
"Comparing iterators from different AutoDiffIndexSubsets");
360466
return current != other.current;
361467
}
362468
};
363469
};
364470

365-
class AutoDiffFunctionParameterSubset {
366-
private:
367-
AutoDiffIndexSubset *indexSubset;
368-
bool curried;
369-
370-
public:
371-
explicit AutoDiffFunctionParameterSubset(
372-
AutoDiffIndexSubset *indexSubset, bool isCurried)
373-
: indexSubset(indexSubset), curried(isCurried) {}
374-
375-
explicit AutoDiffFunctionParameterSubset(
376-
ASTContext &ctx, AutoDiffIndexSubset *parameterSubset,
377-
Optional<bool> isSelfIncluded);
378-
379-
AutoDiffIndexSubset *getIndexSubset() const {
380-
return indexSubset;
381-
}
382-
383-
bool isCurried() const {
384-
return curried;
385-
}
386-
};
387-
388471
/// SIL-level automatic differentiation indices. Consists of a source index,
389472
/// i.e. index of the dependent result to differentiate from, and parameter
390473
/// indices, i.e. index of independent parameters to differentiate with

include/swift/Serialization/ModuleFormat.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5252
/// describe what change you made. The content of this comment isn't important;
5353
/// it just ensures a conflict if two people change the module format.
5454
/// Don't worry about adhering to the 80-column limit for this line.
55-
const uint16_t SWIFTMODULE_VERSION_MINOR = 489; // Last change: `@differentiating` wrt
55+
const uint16_t SWIFTMODULE_VERSION_MINOR = 490; // Last change: `@differentiable` parameter indices layout.
5656

5757
using DeclIDField = BCFixed<31>;
5858

lib/AST/ASTContext.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4536,9 +4536,9 @@ AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) {
45364536
return newNode;
45374537
}
45384538

4539-
AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx,
4540-
unsigned capacity,
4541-
ArrayRef<unsigned> indices) {
4539+
AutoDiffIndexSubset *
4540+
AutoDiffIndexSubset::get(ASTContext &ctx, unsigned capacity,
4541+
ArrayRef<unsigned> indices) {
45424542
auto &foldingSet = ctx.getImpl().AutoDiffIndexSubsets;
45434543
llvm::FoldingSetNodeID id;
45444544
id.AddInteger(capacity);
@@ -4548,19 +4548,19 @@ AutoDiffIndexSubset *AutoDiffIndexSubset::get(ASTContext &ctx,
45484548
for (unsigned index : indices) {
45494549
#ifndef NDEBUG
45504550
assert((int)index > last && "Indices must be ascending");
4551-
last = index;
4551+
last = (int)index;
45524552
#endif
45534553
id.AddInteger(index);
45544554
}
45554555
void *insertPos = nullptr;
45564556
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
45574557
if (existing)
45584558
return existing;
4559-
auto numBitWords = sizeof(AutoDiffIndexSubset) +
4560-
getNumBitWordsNeededForCapacity(capacity);
4559+
auto sizeToAlloc = sizeof(AutoDiffIndexSubset) +
4560+
getNumBitWordsNeededForCapacity(capacity);
45614561
auto *buf = reinterpret_cast<AutoDiffIndexSubset *>(
4562-
ctx.Allocate(numBitWords, alignof(AutoDiffIndexSubset)));
4563-
auto *newNode = new (buf) AutoDiffIndexSubset(capacity, numBitWords, indices);
4562+
ctx.Allocate(sizeToAlloc, alignof(AutoDiffIndexSubset)));
4563+
auto *newNode = new (buf) AutoDiffIndexSubset(capacity, indices);
45644564
foldingSet.InsertNode(newNode, insertPos);
45654565
return newNode;
45664566
}

0 commit comments

Comments
 (0)