Skip to content

Commit fd3a880

Browse files
committed
[AutoDiff upstream] [AST] Add 'AutoDiffIndexSubset' data structure.
`AutoDiffIndexSubset` is a fixed-size bit vector that is used for efficiently representing a subset of indices in automatic differentiation, specifically for representing a subset of parameters and results of a function to differentiate with respect to. It is uniqued in `ASTContext`. This patch adds definition and unit tests for `AutoDiffIndexSubset` along with new files `AutoDiff.h` and `AutoDiff.cpp` into the 'AST' target, with no changes to the compiler's behavior. More data structures used for AutoDiff will be added to these files. ---------------------------- This is part of the ongoing effort to merge the experimental [differentiable programming feature](https://forums.swift.org/t/differentiable-programming-mega-proposal/28547) (informally referred to as "AutoDiff") to the 'master' branch for code reviews and better maintenance. Upstreaming task: [TF-879](https://bugs.swift.org/browse/TF-879)
1 parent b32e82c commit fd3a880

File tree

7 files changed

+639
-0
lines changed

7 files changed

+639
-0
lines changed

include/swift/AST/ASTContext.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ namespace swift {
109109
class TypeAliasDecl;
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112+
class AutoDiffIndexSubset;
112113

113114
enum class KnownProtocolKind : uint8_t;
114115

include/swift/AST/AutoDiff.h

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
//===--- AutoDiff.h - Swift Differentiable Programming --------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines AST support for the experimental differentiable
14+
// programming feature.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
19+
#ifndef SWIFT_AST_AUTODIFF_H
20+
#define SWIFT_AST_AUTODIFF_H
21+
22+
#include "ASTContext.h"
23+
#include "llvm/ADT/SmallBitVector.h"
24+
#include "swift/Basic/Range.h"
25+
26+
namespace swift {
27+
28+
/// An efficient index subset data structure, uniqued in `ASTContext`.
29+
/// Stores a bit vector representing set indices and a total capacity.
30+
class AutoDiffIndexSubset : public llvm::FoldingSetNode {
31+
public:
32+
typedef uint64_t BitWord;
33+
34+
static constexpr unsigned bitWordSize = sizeof(BitWord);
35+
static constexpr unsigned numBitsPerBitWord = bitWordSize * 8;
36+
37+
static std::pair<unsigned, unsigned>
38+
getBitWordIndexAndOffset(unsigned index) {
39+
auto bitWordIndex = index / numBitsPerBitWord;
40+
auto bitWordOffset = index % numBitsPerBitWord;
41+
return {bitWordIndex, bitWordOffset};
42+
}
43+
44+
static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) {
45+
if (capacity == 0) return 0;
46+
return capacity / numBitsPerBitWord + 1;
47+
}
48+
49+
private:
50+
/// The total capacity of the index subset, which is `1` less than the largest
51+
/// index.
52+
unsigned capacity;
53+
/// The number of bit words in the index subset.
54+
unsigned numBitWords;
55+
56+
BitWord *getBitWordsData() {
57+
return reinterpret_cast<BitWord *>(this + 1);
58+
}
59+
60+
const BitWord *getBitWordsData() const {
61+
return reinterpret_cast<const BitWord *>(this + 1);
62+
}
63+
64+
ArrayRef<BitWord> getBitWords() const {
65+
return {getBitWordsData(), getNumBitWords()};
66+
}
67+
68+
BitWord getBitWord(unsigned i) const {
69+
return getBitWordsData()[i];
70+
}
71+
72+
BitWord &getBitWord(unsigned i) {
73+
return getBitWordsData()[i];
74+
}
75+
76+
MutableArrayRef<BitWord> getMutableBitWords() {
77+
return {const_cast<BitWord *>(getBitWordsData()), getNumBitWords()};
78+
}
79+
80+
explicit AutoDiffIndexSubset(const SmallBitVector &indices)
81+
: capacity((unsigned)indices.size()),
82+
numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
83+
std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0);
84+
for (auto i : indices.set_bits()) {
85+
unsigned bitWordIndex, offset;
86+
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i);
87+
getBitWord(bitWordIndex) |= (1 << offset);
88+
}
89+
}
90+
91+
public:
92+
AutoDiffIndexSubset() = delete;
93+
AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete;
94+
AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete;
95+
96+
// Defined in ASTContext.cpp.
97+
static AutoDiffIndexSubset *get(ASTContext &ctx,
98+
const SmallBitVector &indices);
99+
100+
static AutoDiffIndexSubset *get(ASTContext &ctx, unsigned capacity,
101+
ArrayRef<unsigned> indices) {
102+
SmallBitVector indicesBitVec(capacity, false);
103+
for (auto index : indices)
104+
indicesBitVec.set(index);
105+
return AutoDiffIndexSubset::get(ctx, indicesBitVec);
106+
}
107+
108+
static AutoDiffIndexSubset *getDefault(ASTContext &ctx, unsigned capacity,
109+
bool includeAll = false) {
110+
return get(ctx, SmallBitVector(capacity, includeAll));
111+
}
112+
113+
static AutoDiffIndexSubset *getFromRange(ASTContext &ctx, unsigned capacity,
114+
unsigned start, unsigned end) {
115+
assert(start < capacity);
116+
assert(end <= capacity);
117+
SmallBitVector bitVec(capacity);
118+
bitVec.set(start, end);
119+
return get(ctx, bitVec);
120+
}
121+
122+
/// Creates an index subset corresponding to the given string generated by
123+
/// `getString()`. If the string is invalid, returns nullptr.
124+
static AutoDiffIndexSubset *getFromString(ASTContext &ctx, StringRef string);
125+
126+
/// Returns the number of bit words used to store the index subset.
127+
// Note: Use `getCapacity()` to get the total index subset capacity.
128+
// This is public only for unit testing
129+
// (in unittests/AST/SILAutoDiffIndices.cpp).
130+
unsigned getNumBitWords() const {
131+
return numBitWords;
132+
}
133+
134+
/// Returns the capacity of the index subset.
135+
unsigned getCapacity() const {
136+
return capacity;
137+
}
138+
139+
/// Returns a textual string description of these indices.
140+
///
141+
/// It has the format `[SU]+`, where the total number of characters is equal
142+
/// to the capacity, and where "S" means that the corresponding index is
143+
/// contained and "U" means that the corresponding index is not.
144+
std::string getString() const;
145+
146+
class iterator;
147+
148+
iterator begin() const {
149+
return iterator(this);
150+
}
151+
152+
iterator end() const {
153+
return iterator(this, (int)capacity);
154+
}
155+
156+
/// Returns an iterator range of indices in the index subset.
157+
iterator_range<iterator> getIndices() const {
158+
return make_range(begin(), end());
159+
}
160+
161+
/// Returns the number of indices in the index subset.
162+
unsigned getNumIndices() const {
163+
return (unsigned)std::distance(begin(), end());
164+
}
165+
166+
SmallBitVector getBitVector() const {
167+
SmallBitVector indicesBitVec(capacity, false);
168+
for (auto index : getIndices())
169+
indicesBitVec.set(index);
170+
return indicesBitVec;
171+
}
172+
173+
bool contains(unsigned index) const {
174+
unsigned bitWordIndex, offset;
175+
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index);
176+
return getBitWord(bitWordIndex) & (1 << offset);
177+
}
178+
179+
bool isEmpty() const {
180+
return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
181+
}
182+
183+
bool equals(AutoDiffIndexSubset *other) const {
184+
return capacity == other->getCapacity() &&
185+
getBitWords().equals(other->getBitWords());
186+
}
187+
188+
bool isSubsetOf(AutoDiffIndexSubset *other) const;
189+
bool isSupersetOf(AutoDiffIndexSubset *other) const;
190+
191+
AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const;
192+
AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx,
193+
unsigned newCapacity) const;
194+
195+
void Profile(llvm::FoldingSetNodeID &id) const {
196+
id.AddInteger(capacity);
197+
for (auto index : getIndices())
198+
id.AddInteger(index);
199+
}
200+
201+
void print(llvm::raw_ostream &s = llvm::outs()) const {
202+
s << '{';
203+
interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); },
204+
[&s] { s << ", "; });
205+
s << '}';
206+
}
207+
208+
void dump(llvm::raw_ostream &s = llvm::errs()) const {
209+
s << "(autodiff_index_subset capacity=" << capacity << " indices=(";
210+
interleave(getIndices(), [&s](unsigned i) { s << i; },
211+
[&s] { s << ", "; });
212+
s << "))";
213+
}
214+
215+
int findNext(int startIndex) const;
216+
int findFirst() const { return findNext(-1); }
217+
int findPrevious(int endIndex) const;
218+
int findLast() const { return findPrevious(capacity); }
219+
220+
class iterator {
221+
public:
222+
typedef unsigned value_type;
223+
typedef unsigned difference_type;
224+
typedef unsigned * pointer;
225+
typedef unsigned & reference;
226+
typedef std::forward_iterator_tag iterator_category;
227+
228+
private:
229+
const AutoDiffIndexSubset *parent;
230+
int current = 0;
231+
232+
void advance() {
233+
assert(current != -1 && "Trying to advance past end.");
234+
current = parent->findNext(current);
235+
}
236+
237+
public:
238+
iterator(const AutoDiffIndexSubset *parent, int current)
239+
: parent(parent), current(current) {}
240+
explicit iterator(const AutoDiffIndexSubset *parent)
241+
: iterator(parent, parent->findFirst()) {}
242+
iterator(const iterator &) = default;
243+
244+
iterator operator++(int) {
245+
auto prev = *this;
246+
advance();
247+
return prev;
248+
}
249+
250+
iterator &operator++() {
251+
advance();
252+
return *this;
253+
}
254+
255+
unsigned operator*() const { return current; }
256+
257+
bool operator==(const iterator &other) const {
258+
assert(parent == other.parent &&
259+
"Comparing iterators from different AutoDiffIndexSubsets");
260+
return current == other.current;
261+
}
262+
263+
bool operator!=(const iterator &other) const {
264+
assert(parent == other.parent &&
265+
"Comparing iterators from different AutoDiffIndexSubsets");
266+
return current != other.current;
267+
}
268+
};
269+
};
270+
271+
} // end namespace swift
272+
273+
#endif // SWIFT_AST_AUTODIFF_H

lib/AST/ASTContext.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "swift/AST/ASTContext.h"
1818
#include "ForeignRepresentationInfo.h"
1919
#include "SubstitutionMapStorage.h"
20+
#include "swift/AST/AutoDiff.h"
2021
#include "swift/AST/ClangModuleLoader.h"
2122
#include "swift/AST/ConcreteDeclRef.h"
2223
#include "swift/AST/DiagnosticEngine.h"
@@ -427,6 +428,9 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
427428
llvm::FoldingSet<DeclName::CompoundDeclName> CompoundNames;
428429
llvm::DenseMap<UUID, OpenedArchetypeType *> OpenedExistentialArchetypes;
429430

431+
/// For uniquifying `AutoDiffIndexSubset` allocations.
432+
llvm::FoldingSet<AutoDiffIndexSubset> AutoDiffIndexSubsets;
433+
430434
/// A cache of information about whether particular nominal types
431435
/// are representable in a foreign language.
432436
llvm::DenseMap<NominalTypeDecl *, ForeignRepresentationInfo>
@@ -4616,3 +4620,24 @@ void VarDecl::setOriginalWrappedProperty(VarDecl *originalProperty) {
46164620
assert(ctx.getImpl().OriginalWrappedProperties.count(this) == 0);
46174621
ctx.getImpl().OriginalWrappedProperties[this] = originalProperty;
46184622
}
4623+
4624+
AutoDiffIndexSubset *
4625+
AutoDiffIndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
4626+
auto &foldingSet = ctx.getImpl().AutoDiffIndexSubsets;
4627+
llvm::FoldingSetNodeID id;
4628+
unsigned capacity = indices.size();
4629+
id.AddInteger(capacity);
4630+
for (unsigned index : indices.set_bits())
4631+
id.AddInteger(index);
4632+
void *insertPos = nullptr;
4633+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
4634+
if (existing)
4635+
return existing;
4636+
auto sizeToAlloc = sizeof(AutoDiffIndexSubset) +
4637+
getNumBitWordsNeededForCapacity(capacity);
4638+
auto *buf = reinterpret_cast<AutoDiffIndexSubset *>(
4639+
ctx.Allocate(sizeToAlloc, alignof(AutoDiffIndexSubset)));
4640+
auto *newNode = new (buf) AutoDiffIndexSubset(indices);
4641+
foldingSet.InsertNode(newNode, insertPos);
4642+
return newNode;
4643+
}

0 commit comments

Comments
 (0)