Skip to content

Commit 3d3f9e4

Browse files
authored
[AutoDiff] Parameter indices data structure overhaul (#24761)
### Introduce `AutoDiffIndexSubset` This PR overhauls the data structures used for parameter indices in the AutoDiff infrastructure in SIL. Previously, we used `llvm::SmallBitVector` to represent differentiation parameter indices in both AST and SIL. It was not efficient and, most importantly, there was no way to put it in an instruction without causing memory leaks. This change replaces all uses of `llvm::SmallBitVector` in SIL AutoDiff code paths with an `ASTContext`-uniqued `AutoDiffIndexSubset *` where bits are stored as trailing objects. `AutoDiffIndexSubset` does not have "parameter indices" in its name because it is designed not only for parameter indices, but also for result indices as we move to supporting multi-result differentiation. `AutoDiffIndexSubset` has set operations like `isSubsetOf`, `isSupersetOf`, and `contains`, but it also has a special _capacity_ property. Every differentiable function's parameter indices should store the number of parameters as the capacity, so that the differentiation transform won't need special logic to check whether an index is out of range. Another minor change is the module format layout of `SILDifferentiableAttr`. It used to store parameter indices as consecutive `bool` bits, but now stores numeric parameter indices directly for efficiency. It will be necessary to refactor or eliminate `AutoDiffParameterIndices` to make use of `AutoDiffIndexSubset`. `AutoDiffParameterIndices` is at the AST level, so it is out of scope for this PR. Unblocks #23482. Partially resolves [TF-67](https://bugs.swift.org/browse/TF-67).
1 parent 054c4be commit 3d3f9e4

24 files changed

+680
-236
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 224 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "ASTContext.h"
2222
#include "llvm/ADT/SmallBitVector.h"
23+
#include "swift/Basic/Range.h"
2324

2425
namespace swift {
2526

@@ -73,6 +74,7 @@ class ParsedAutoDiffParameter {
7374
};
7475

7576
class AnyFunctionType;
77+
class AutoDiffIndexSubset;
7678
class AutoDiffParameterIndicesBuilder;
7779
class Type;
7880

@@ -173,7 +175,8 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
173175
/// ==> returns 1110
174176
/// (because the lowered SIL type is (A, B, C, D) -> R)
175177
///
176-
llvm::SmallBitVector getLowered(AnyFunctionType *functionType) const;
178+
AutoDiffIndexSubset *getLowered(ASTContext &ctx,
179+
AnyFunctionType *functionType) const;
177180

178181
void Profile(llvm::FoldingSetNodeID &ID) const {
179182
ID.AddInteger(parameters.size());
@@ -219,6 +222,216 @@ class AutoDiffParameterIndicesBuilder {
219222
unsigned size() { return parameters.size(); }
220223
};
221224

225+
class AutoDiffIndexSubset : public llvm::FoldingSetNode {
226+
public:
227+
typedef uint64_t BitWord;
228+
229+
static constexpr unsigned bitWordSize = sizeof(BitWord);
230+
static constexpr unsigned numBitsPerBitWord = bitWordSize * 8;
231+
232+
static std::pair<unsigned, unsigned>
233+
getBitWordIndexAndOffset(unsigned index) {
234+
auto bitWordIndex = index / numBitsPerBitWord;
235+
auto bitWordOffset = index % numBitsPerBitWord;
236+
return {bitWordIndex, bitWordOffset};
237+
}
238+
239+
static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) {
240+
if (capacity == 0) return 0;
241+
return capacity / numBitsPerBitWord + 1;
242+
}
243+
244+
private:
245+
/// The total capacity of the index subset, which is `1` less than the largest
246+
/// index.
247+
unsigned capacity;
248+
/// The number of bit words in the index subset.
249+
unsigned numBitWords;
250+
251+
BitWord *getBitWordsData() {
252+
return reinterpret_cast<BitWord *>(this + 1);
253+
}
254+
255+
const BitWord *getBitWordsData() const {
256+
return reinterpret_cast<const BitWord *>(this + 1);
257+
}
258+
259+
ArrayRef<BitWord> getBitWords() const {
260+
return {getBitWordsData(), getNumBitWords()};
261+
}
262+
263+
BitWord getBitWord(unsigned i) const {
264+
return getBitWordsData()[i];
265+
}
266+
267+
BitWord &getBitWord(unsigned i) {
268+
return getBitWordsData()[i];
269+
}
270+
271+
MutableArrayRef<BitWord> getMutableBitWords() {
272+
return {const_cast<BitWord *>(getBitWordsData()), getNumBitWords()};
273+
}
274+
275+
explicit AutoDiffIndexSubset(unsigned capacity, ArrayRef<unsigned> indices)
276+
: capacity(capacity),
277+
numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
278+
std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0);
279+
for (auto i : indices) {
280+
unsigned bitWordIndex, offset;
281+
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i);
282+
getBitWord(bitWordIndex) |= (1 << offset);
283+
}
284+
}
285+
286+
public:
287+
AutoDiffIndexSubset() = delete;
288+
AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete;
289+
AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete;
290+
291+
// Defined in ASTContext.h.
292+
static AutoDiffIndexSubset *get(ASTContext &ctx,
293+
unsigned capacity,
294+
ArrayRef<unsigned> indices);
295+
296+
static AutoDiffIndexSubset *getDefault(ASTContext &ctx,
297+
unsigned capacity,
298+
bool includeAll = false) {
299+
if (includeAll)
300+
return getFromRange(ctx, capacity, IntRange<>(capacity));
301+
return get(ctx, capacity, {});
302+
}
303+
304+
static AutoDiffIndexSubset *getFromRange(ASTContext &ctx,
305+
unsigned capacity,
306+
IntRange<> range) {
307+
return get(ctx, capacity,
308+
SmallVector<unsigned, 8>(range.begin(), range.end()));
309+
}
310+
311+
unsigned getNumBitWords() const {
312+
return numBitWords;
313+
}
314+
315+
unsigned getCapacity() const {
316+
return capacity;
317+
}
318+
319+
class iterator;
320+
321+
iterator begin() const {
322+
return iterator(this);
323+
}
324+
325+
iterator end() const {
326+
return iterator(this, (int)capacity);
327+
}
328+
329+
iterator_range<iterator> getIndices() const {
330+
return make_range(begin(), end());
331+
}
332+
333+
unsigned getNumIndices() const {
334+
return (unsigned)std::distance(begin(), end());
335+
}
336+
337+
bool contains(unsigned index) const {
338+
unsigned bitWordIndex, offset;
339+
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index);
340+
return getBitWord(bitWordIndex) & (1 << offset);
341+
}
342+
343+
bool isEmpty() const {
344+
return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
345+
}
346+
347+
bool equals(AutoDiffIndexSubset *other) const {
348+
return capacity == other->getCapacity() &&
349+
getBitWords().equals(other->getBitWords());
350+
}
351+
352+
bool isSubsetOf(AutoDiffIndexSubset *other) const;
353+
bool isSupersetOf(AutoDiffIndexSubset *other) const;
354+
355+
AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const;
356+
AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx,
357+
unsigned newCapacity) const;
358+
359+
void Profile(llvm::FoldingSetNodeID &id) const {
360+
id.AddInteger(capacity);
361+
for (auto index : getIndices())
362+
id.AddInteger(index);
363+
}
364+
365+
void print(llvm::raw_ostream &s = llvm::outs()) const {
366+
s << '{';
367+
interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); },
368+
[&s] { s << ", "; });
369+
s << '}';
370+
}
371+
372+
void dump(llvm::raw_ostream &s = llvm::errs()) const {
373+
s << "(autodiff_index_subset capacity=" << capacity << " indices=(";
374+
interleave(getIndices(), [&s](unsigned i) { s << i; },
375+
[&s] { s << ", "; });
376+
s << "))";
377+
}
378+
379+
int findNext(int startIndex) const;
380+
int findFirst() const { return findNext(-1); }
381+
int findPrevious(int endIndex) const;
382+
int findLast() const { return findPrevious(capacity); }
383+
384+
class iterator {
385+
public:
386+
typedef unsigned value_type;
387+
typedef unsigned difference_type;
388+
typedef unsigned * pointer;
389+
typedef unsigned & reference;
390+
typedef std::forward_iterator_tag iterator_category;
391+
392+
private:
393+
const AutoDiffIndexSubset *parent;
394+
int current = 0;
395+
396+
void advance() {
397+
assert(current != -1 && "Trying to advance past end.");
398+
current = parent->findNext(current);
399+
}
400+
401+
public:
402+
iterator(const AutoDiffIndexSubset *parent, int current)
403+
: parent(parent), current(current) {}
404+
explicit iterator(const AutoDiffIndexSubset *parent)
405+
: iterator(parent, parent->findFirst()) {}
406+
iterator(const iterator &) = default;
407+
408+
iterator operator++(int) {
409+
auto prev = *this;
410+
advance();
411+
return prev;
412+
}
413+
414+
iterator &operator++() {
415+
advance();
416+
return *this;
417+
}
418+
419+
unsigned operator*() const { return current; }
420+
421+
bool operator==(const iterator &other) const {
422+
assert(parent == other.parent &&
423+
"Comparing iterators from different AutoDiffIndexSubsets");
424+
return current == other.current;
425+
}
426+
427+
bool operator!=(const iterator &other) const {
428+
assert(parent == other.parent &&
429+
"Comparing iterators from different AutoDiffIndexSubsets");
430+
return current != other.current;
431+
}
432+
};
433+
};
434+
222435
/// SIL-level automatic differentiation indices. Consists of a source index,
223436
/// i.e. index of the dependent result to differentiate from, and parameter
224437
/// indices, i.e. index of independent parameters to differentiate with
@@ -242,38 +455,33 @@ struct SILAutoDiffIndices {
242455
/// Function type: (A, B) -> (C, D) -> R
243456
/// Bits: [C][D][A][B]
244457
///
245-
llvm::SmallBitVector parameters;
458+
AutoDiffIndexSubset *parameters;
246459

247460
/// Creates a set of AD indices from the given source index and an index subset
248461
/// representing parameter indices.
249462
/*implicit*/ SILAutoDiffIndices(unsigned source,
250-
llvm::SmallBitVector parameters)
463+
AutoDiffIndexSubset *parameters)
251464
: source(source), parameters(parameters) {}
252465

253-
/// Creates a set of AD indices from the given source index and an array of
254-
/// parameter indices. Elements in `parameters` must be ascending integers.
255-
/*implicit*/ SILAutoDiffIndices(unsigned source,
256-
ArrayRef<unsigned> parameters);
257-
258466
bool operator==(const SILAutoDiffIndices &other) const;
259467

260468
/// Queries whether the function's parameter with index `parameterIndex` is
261469
/// one of the parameters to differentiate with respect to.
262470
bool isWrtParameter(unsigned parameterIndex) const {
263-
return parameterIndex < parameters.size() &&
264-
parameters.test(parameterIndex);
471+
return parameterIndex < parameters->getCapacity() &&
472+
parameters->contains(parameterIndex);
265473
}
266474

267475
void print(llvm::raw_ostream &s = llvm::outs()) const {
268476
s << "(source=" << source << " parameters=(";
269-
interleave(parameters.set_bits(),
477+
interleave(parameters->getIndices(),
270478
[&s](unsigned p) { s << p; }, [&s]{ s << ' '; });
271479
s << "))";
272480
}
273481

274482
std::string mangle() const {
275483
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
276-
interleave(parameters.set_bits(),
484+
interleave(parameters->getIndices(),
277485
[&](unsigned idx) { result += llvm::utostr(idx); },
278486
[&] { result += '_'; });
279487
return result;
@@ -449,19 +657,18 @@ template<typename T> struct DenseMapInfo;
449657

450658
template<> struct DenseMapInfo<SILAutoDiffIndices> {
451659
static SILAutoDiffIndices getEmptyKey() {
452-
return { DenseMapInfo<unsigned>::getEmptyKey(), SmallBitVector() };
660+
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };
453661
}
454662

455663
static SILAutoDiffIndices getTombstoneKey() {
456-
return { DenseMapInfo<unsigned>::getTombstoneKey(),
457-
SmallBitVector(sizeof(intptr_t), true) };
664+
return { DenseMapInfo<unsigned>::getTombstoneKey(), nullptr };
458665
}
459666

460667
static unsigned getHashValue(const SILAutoDiffIndices &Val) {
461-
auto params = Val.parameters.set_bits();
462668
unsigned combinedHash =
463669
hash_combine(~1U, DenseMapInfo<unsigned>::getHashValue(Val.source),
464-
hash_combine_range(params.begin(), params.end()));
670+
hash_combine_range(Val.parameters->begin(),
671+
Val.parameters->end()));
465672
return combinedHash;
466673
}
467674

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1535,7 +1535,9 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
15351535
ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
15361536
"the number of operand lists does not match the order", ())
15371537
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
1538-
"expects an assoiacted function kind attribute, e.g. '[jvp]'", ())
1538+
"expected an associated function kind attribute, e.g. '[jvp]'", ())
1539+
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
1540+
"expected an operand of a function type", ())
15391541

15401542
//------------------------------------------------------------------------------
15411543
// MARK: Generics parsing diagnostics

include/swift/AST/Types.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4132,14 +4132,14 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41324132

41334133
// SWIFT_ENABLE_TENSORFLOW
41344134
CanSILFunctionType getWithDifferentiability(
4135-
unsigned differentiationOrder, const SmallBitVector &parameterIndices);
4135+
unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices);
41364136

41374137
CanSILFunctionType getWithoutDifferentiability();
41384138

41394139
/// Returns the type of a differentiation function that is associated with
41404140
/// a function of this type.
41414141
CanSILFunctionType getAutoDiffAssociatedFunctionType(
4142-
const SmallBitVector &parameterIndices, unsigned resultIndex,
4142+
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
41434143
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
41444144
SILModule &module, LookupConformanceFn lookupConformance,
41454145
GenericSignature *whereClauseGenericSignature = nullptr);
@@ -4148,7 +4148,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41484148
/// differentiate with respect to for this differentiable function type. (e.g.
41494149
/// which parameters are not @nondiff). The function type must be
41504150
/// differentiable.
4151-
SmallBitVector getDifferentiationParameterIndices() const;
4151+
AutoDiffIndexSubset *getDifferentiationParameterIndices();
41524152

41534153
/// If this is a @convention(witness_method) function with a class
41544154
/// constrained self parameter, return the class constraint for the

include/swift/SIL/SILBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ class SILBuilder {
504504

505505
/// SWIFT_ENABLE_TENSORFLOW
506506
AutoDiffFunctionInst *createAutoDiffFunction(
507-
SILLocation loc, const llvm::SmallBitVector &parameterIndices,
507+
SILLocation loc, AutoDiffIndexSubset *parameterIndices,
508508
unsigned differentiationOrder, SILValue original,
509509
ArrayRef<SILValue> associatedFunctions = {}) {
510510
return insert(AutoDiffFunctionInst::create(getModule(),

include/swift/SIL/SILFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class SILDifferentiableAttr final {
174174
SILFunction *getOriginal() const { return Original; }
175175

176176
const SILAutoDiffIndices &getIndices() const { return indices; }
177+
/// Overwrites the AD indices stored on this attribute with `indices`.
void setIndices(const SILAutoDiffIndices &indices) {
178+
this->indices = indices;
179+
}
177180

178181
TrailingWhereClause *getWhereClause() const { return WhereClause; }
179182

0 commit comments

Comments
 (0)