Skip to content

Commit adfd8b3

Browse files
authored
Merge pull request #27555 from rxwei/ad-upstream-autodiff-index-subset
2 parents 94cafc0 + 0e9425e commit adfd8b3

File tree

7 files changed

+619
-0
lines changed

7 files changed

+619
-0
lines changed

include/swift/AST/ASTContext.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ namespace swift {
109109
class TypeAliasDecl;
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112+
class IndexSubset;
112113

113114
enum class KnownProtocolKind : uint8_t;
114115

include/swift/AST/IndexSubset.h

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
//===---------- IndexSubset.h - Fixed-size subset of indices --------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines the `IndexSubset` class and support logic.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef SWIFT_AST_INDEXSUBSET_H
18+
#define SWIFT_AST_INDEXSUBSET_H
19+
20+
#include "swift/Basic/LLVM.h"
21+
#include "swift/Basic/Range.h"
22+
#include "swift/Basic/STLExtras.h"
23+
#include "llvm/ADT/ArrayRef.h"
24+
#include "llvm/ADT/FoldingSet.h"
25+
#include "llvm/ADT/SmallBitVector.h"
26+
#include "llvm/Support/raw_ostream.h"
27+
28+
namespace swift {
29+
30+
class ASTContext;
31+
32+
/// An efficient index subset data structure, uniqued in `ASTContext`.
33+
/// Stores a bit vector representing set indices and a total capacity.
34+
class IndexSubset : public llvm::FoldingSetNode {
35+
public:
36+
typedef uint64_t BitWord;
37+
38+
static constexpr unsigned bitWordSize = sizeof(BitWord);
39+
static constexpr unsigned numBitsPerBitWord = bitWordSize * 8;
40+
41+
static std::pair<unsigned, unsigned>
42+
getBitWordIndexAndOffset(unsigned index) {
43+
auto bitWordIndex = index / numBitsPerBitWord;
44+
auto bitWordOffset = index % numBitsPerBitWord;
45+
return {bitWordIndex, bitWordOffset};
46+
}
47+
48+
static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) {
49+
if (capacity == 0) return 0;
50+
return capacity / numBitsPerBitWord + 1;
51+
}
52+
53+
private:
54+
/// The total capacity of the index subset, which is `1` less than the largest
55+
/// index.
56+
unsigned capacity;
57+
/// The number of bit words in the index subset.
58+
unsigned numBitWords;
59+
60+
BitWord *getBitWordsData() {
61+
return reinterpret_cast<BitWord *>(this + 1);
62+
}
63+
64+
const BitWord *getBitWordsData() const {
65+
return reinterpret_cast<const BitWord *>(this + 1);
66+
}
67+
68+
ArrayRef<BitWord> getBitWords() const {
69+
return {getBitWordsData(), getNumBitWords()};
70+
}
71+
72+
BitWord getBitWord(unsigned i) const {
73+
return getBitWordsData()[i];
74+
}
75+
76+
BitWord &getBitWord(unsigned i) {
77+
return getBitWordsData()[i];
78+
}
79+
80+
MutableArrayRef<BitWord> getMutableBitWords() {
81+
return {const_cast<BitWord *>(getBitWordsData()), getNumBitWords()};
82+
}
83+
84+
explicit IndexSubset(const SmallBitVector &indices)
85+
: capacity((unsigned)indices.size()),
86+
numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
87+
std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0);
88+
for (auto i : indices.set_bits()) {
89+
unsigned bitWordIndex, offset;
90+
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i);
91+
getBitWord(bitWordIndex) |= (1 << offset);
92+
}
93+
}
94+
95+
public:
96+
IndexSubset() = delete;
97+
IndexSubset(const IndexSubset &) = delete;
98+
IndexSubset &operator=(const IndexSubset &) = delete;
99+
100+
// Defined in ASTContext.cpp.
101+
static IndexSubset *get(ASTContext &ctx, const SmallBitVector &indices);
102+
103+
static IndexSubset *get(ASTContext &ctx, unsigned capacity,
104+
ArrayRef<unsigned> indices) {
105+
SmallBitVector indicesBitVec(capacity, false);
106+
for (auto index : indices)
107+
indicesBitVec.set(index);
108+
return IndexSubset::get(ctx, indicesBitVec);
109+
}
110+
111+
static IndexSubset *getDefault(ASTContext &ctx, unsigned capacity,
112+
bool includeAll = false) {
113+
return get(ctx, SmallBitVector(capacity, includeAll));
114+
}
115+
116+
static IndexSubset *getFromRange(ASTContext &ctx, unsigned capacity,
117+
unsigned start, unsigned end) {
118+
assert(start < capacity);
119+
assert(end <= capacity);
120+
SmallBitVector bitVec(capacity);
121+
bitVec.set(start, end);
122+
return get(ctx, bitVec);
123+
}
124+
125+
/// Creates an index subset corresponding to the given string generated by
126+
/// `getString()`. If the string is invalid, returns nullptr.
127+
static IndexSubset *getFromString(ASTContext &ctx, StringRef string);
128+
129+
/// Returns the number of bit words used to store the index subset.
130+
// Note: Use `getCapacity()` to get the total index subset capacity.
131+
// This is public only for unit testing
132+
// (in unittests/AST/SILAutoDiffIndices.cpp).
133+
unsigned getNumBitWords() const {
134+
return numBitWords;
135+
}
136+
137+
/// Returns the capacity of the index subset.
138+
unsigned getCapacity() const {
139+
return capacity;
140+
}
141+
142+
/// Returns a textual string description of these indices.
143+
///
144+
/// It has the format `[SU]+`, where the total number of characters is equal
145+
/// to the capacity, and where "S" means that the corresponding index is
146+
/// contained and "U" means that the corresponding index is not.
147+
std::string getString() const;
148+
149+
class iterator;
150+
151+
iterator begin() const {
152+
return iterator(this);
153+
}
154+
155+
iterator end() const {
156+
return iterator(this, (int)capacity);
157+
}
158+
159+
/// Returns an iterator range of indices in the index subset.
160+
iterator_range<iterator> getIndices() const {
161+
return make_range(begin(), end());
162+
}
163+
164+
/// Returns the number of indices in the index subset.
165+
unsigned getNumIndices() const {
166+
return (unsigned)std::distance(begin(), end());
167+
}
168+
169+
SmallBitVector getBitVector() const {
170+
SmallBitVector indicesBitVec(capacity, false);
171+
for (auto index : getIndices())
172+
indicesBitVec.set(index);
173+
return indicesBitVec;
174+
}
175+
176+
bool contains(unsigned index) const {
177+
unsigned bitWordIndex, offset;
178+
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index);
179+
return getBitWord(bitWordIndex) & (1 << offset);
180+
}
181+
182+
bool isEmpty() const {
183+
return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
184+
}
185+
186+
bool equals(IndexSubset *other) const {
187+
return capacity == other->getCapacity() &&
188+
getBitWords().equals(other->getBitWords());
189+
}
190+
191+
bool isSubsetOf(IndexSubset *other) const;
192+
bool isSupersetOf(IndexSubset *other) const;
193+
194+
IndexSubset *adding(unsigned index, ASTContext &ctx) const;
195+
IndexSubset *extendingCapacity(ASTContext &ctx,
196+
unsigned newCapacity) const;
197+
198+
void Profile(llvm::FoldingSetNodeID &id) const {
199+
id.AddInteger(capacity);
200+
for (auto index : getIndices())
201+
id.AddInteger(index);
202+
}
203+
204+
void print(llvm::raw_ostream &s = llvm::outs()) const {
205+
s << '{';
206+
interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); },
207+
[&s] { s << ", "; });
208+
s << '}';
209+
}
210+
211+
void dump(llvm::raw_ostream &s = llvm::errs()) const {
212+
s << "(index_subset capacity=" << capacity << " indices=(";
213+
interleave(getIndices(), [&s](unsigned i) { s << i; },
214+
[&s] { s << ", "; });
215+
s << "))";
216+
}
217+
218+
int findNext(int startIndex) const;
219+
int findFirst() const { return findNext(-1); }
220+
int findPrevious(int endIndex) const;
221+
int findLast() const { return findPrevious(capacity); }
222+
223+
class iterator {
224+
public:
225+
typedef unsigned value_type;
226+
typedef unsigned difference_type;
227+
typedef unsigned * pointer;
228+
typedef unsigned & reference;
229+
typedef std::forward_iterator_tag iterator_category;
230+
231+
private:
232+
const IndexSubset *parent;
233+
int current = 0;
234+
235+
void advance() {
236+
assert(current != -1 && "Trying to advance past end.");
237+
current = parent->findNext(current);
238+
}
239+
240+
public:
241+
iterator(const IndexSubset *parent, int current)
242+
: parent(parent), current(current) {}
243+
explicit iterator(const IndexSubset *parent)
244+
: iterator(parent, parent->findFirst()) {}
245+
iterator(const iterator &) = default;
246+
247+
iterator operator++(int) {
248+
auto prev = *this;
249+
advance();
250+
return prev;
251+
}
252+
253+
iterator &operator++() {
254+
advance();
255+
return *this;
256+
}
257+
258+
unsigned operator*() const { return current; }
259+
260+
bool operator==(const iterator &other) const {
261+
assert(parent == other.parent &&
262+
"Comparing iterators from different IndexSubsets");
263+
return current == other.current;
264+
}
265+
266+
bool operator!=(const iterator &other) const {
267+
assert(parent == other.parent &&
268+
"Comparing iterators from different IndexSubsets");
269+
return current != other.current;
270+
}
271+
};
272+
};
273+
274+
}
275+
276+
#endif // SWIFT_AST_INDEXSUBSET_H

lib/AST/ASTContext.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "swift/AST/GenericSignature.h"
2929
#include "swift/AST/GenericSignatureBuilder.h"
3030
#include "swift/AST/ImportCache.h"
31+
#include "swift/AST/IndexSubset.h"
3132
#include "swift/AST/KnownProtocols.h"
3233
#include "swift/AST/LazyResolver.h"
3334
#include "swift/AST/ModuleLoader.h"
@@ -427,6 +428,9 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
427428
llvm::FoldingSet<DeclName::CompoundDeclName> CompoundNames;
428429
llvm::DenseMap<UUID, OpenedArchetypeType *> OpenedExistentialArchetypes;
429430

431+
/// For uniquifying `IndexSubset` allocations.
432+
llvm::FoldingSet<IndexSubset> IndexSubsets;
433+
430434
/// A cache of information about whether particular nominal types
431435
/// are representable in a foreign language.
432436
llvm::DenseMap<NominalTypeDecl *, ForeignRepresentationInfo>
@@ -4616,3 +4620,24 @@ void VarDecl::setOriginalWrappedProperty(VarDecl *originalProperty) {
46164620
assert(ctx.getImpl().OriginalWrappedProperties.count(this) == 0);
46174621
ctx.getImpl().OriginalWrappedProperties[this] = originalProperty;
46184622
}
4623+
4624+
IndexSubset *
4625+
IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
4626+
auto &foldingSet = ctx.getImpl().IndexSubsets;
4627+
llvm::FoldingSetNodeID id;
4628+
unsigned capacity = indices.size();
4629+
id.AddInteger(capacity);
4630+
for (unsigned index : indices.set_bits())
4631+
id.AddInteger(index);
4632+
void *insertPos = nullptr;
4633+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
4634+
if (existing)
4635+
return existing;
4636+
auto sizeToAlloc = sizeof(IndexSubset) +
4637+
getNumBitWordsNeededForCapacity(capacity);
4638+
auto *buf = reinterpret_cast<IndexSubset *>(
4639+
ctx.Allocate(sizeToAlloc, alignof(IndexSubset)));
4640+
auto *newNode = new (buf) IndexSubset(indices);
4641+
foldingSet.InsertNode(newNode, insertPos);
4642+
return newNode;
4643+
}

lib/AST/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_swift_host_library(swiftAST STATIC
2929
ASTVerifier.cpp
3030
ASTWalker.cpp
3131
Attr.cpp
32+
IndexSubset.cpp
3233
Availability.cpp
3334
AvailabilitySpec.cpp
3435
Builtins.cpp

0 commit comments

Comments
 (0)