20
20
21
21
#include " ASTContext.h"
22
22
#include " llvm/ADT/SmallBitVector.h"
23
+ #include " swift/Basic/Range.h"
23
24
24
25
namespace swift {
25
26
@@ -73,6 +74,7 @@ class ParsedAutoDiffParameter {
73
74
};
74
75
75
76
class AnyFunctionType ;
77
+ class AutoDiffIndexSubset ;
76
78
class AutoDiffParameterIndicesBuilder ;
77
79
class Type ;
78
80
@@ -173,7 +175,8 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
173
175
// / ==> returns 1110
174
176
// / (because the lowered SIL type is (A, B, C, D) -> R)
175
177
// /
176
- llvm::SmallBitVector getLowered (AnyFunctionType *functionType) const ;
178
+ AutoDiffIndexSubset *getLowered (ASTContext &ctx,
179
+ AnyFunctionType *functionType) const ;
177
180
178
181
void Profile (llvm::FoldingSetNodeID &ID) const {
179
182
ID.AddInteger (parameters.size ());
@@ -219,6 +222,216 @@ class AutoDiffParameterIndicesBuilder {
219
222
unsigned size () { return parameters.size (); }
220
223
};
221
224
225
+ class AutoDiffIndexSubset : public llvm ::FoldingSetNode {
226
+ public:
227
+ typedef uint64_t BitWord;
228
+
229
+ static constexpr unsigned bitWordSize = sizeof (BitWord);
230
+ static constexpr unsigned numBitsPerBitWord = bitWordSize * 8 ;
231
+
232
+ static std::pair<unsigned , unsigned >
233
+ getBitWordIndexAndOffset (unsigned index) {
234
+ auto bitWordIndex = index / numBitsPerBitWord;
235
+ auto bitWordOffset = index % numBitsPerBitWord;
236
+ return {bitWordIndex, bitWordOffset};
237
+ }
238
+
239
+ static unsigned getNumBitWordsNeededForCapacity (unsigned capacity) {
240
+ if (capacity == 0 ) return 0 ;
241
+ return capacity / numBitsPerBitWord + 1 ;
242
+ }
243
+
244
+ private:
245
+ // / The total capacity of the index subset, which is `1` less than the largest
246
+ // / index.
247
+ unsigned capacity;
248
+ // / The number of bit words in the index subset.
249
+ unsigned numBitWords;
250
+
251
+ BitWord *getBitWordsData () {
252
+ return reinterpret_cast <BitWord *>(this + 1 );
253
+ }
254
+
255
+ const BitWord *getBitWordsData () const {
256
+ return reinterpret_cast <const BitWord *>(this + 1 );
257
+ }
258
+
259
+ ArrayRef<BitWord> getBitWords () const {
260
+ return {getBitWordsData (), getNumBitWords ()};
261
+ }
262
+
263
+ BitWord getBitWord (unsigned i) const {
264
+ return getBitWordsData ()[i];
265
+ }
266
+
267
+ BitWord &getBitWord (unsigned i) {
268
+ return getBitWordsData ()[i];
269
+ }
270
+
271
+ MutableArrayRef<BitWord> getMutableBitWords () {
272
+ return {const_cast <BitWord *>(getBitWordsData ()), getNumBitWords ()};
273
+ }
274
+
275
+ explicit AutoDiffIndexSubset (unsigned capacity, ArrayRef<unsigned > indices)
276
+ : capacity(capacity),
277
+ numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
278
+ std::uninitialized_fill_n (getBitWordsData (), numBitWords, 0 );
279
+ for (auto i : indices) {
280
+ unsigned bitWordIndex, offset;
281
+ std::tie (bitWordIndex, offset) = getBitWordIndexAndOffset (i);
282
+ getBitWord (bitWordIndex) |= (1 << offset);
283
+ }
284
+ }
285
+
286
+ public:
287
+ AutoDiffIndexSubset () = delete ;
288
+ AutoDiffIndexSubset (const AutoDiffIndexSubset &) = delete ;
289
+ AutoDiffIndexSubset &operator =(const AutoDiffIndexSubset &) = delete ;
290
+
291
+ // Defined in ASTContext.h.
292
+ static AutoDiffIndexSubset *get (ASTContext &ctx,
293
+ unsigned capacity,
294
+ ArrayRef<unsigned > indices);
295
+
296
+ static AutoDiffIndexSubset *getDefault (ASTContext &ctx,
297
+ unsigned capacity,
298
+ bool includeAll = false ) {
299
+ if (includeAll)
300
+ return getFromRange (ctx, capacity, IntRange<>(capacity));
301
+ return get (ctx, capacity, {});
302
+ }
303
+
304
+ static AutoDiffIndexSubset *getFromRange (ASTContext &ctx,
305
+ unsigned capacity,
306
+ IntRange<> range) {
307
+ return get (ctx, capacity,
308
+ SmallVector<unsigned , 8 >(range.begin (), range.end ()));
309
+ }
310
+
311
+ unsigned getNumBitWords () const {
312
+ return numBitWords;
313
+ }
314
+
315
+ unsigned getCapacity () const {
316
+ return capacity;
317
+ }
318
+
319
+ class iterator ;
320
+
321
+ iterator begin () const {
322
+ return iterator (this );
323
+ }
324
+
325
+ iterator end () const {
326
+ return iterator (this , (int )capacity);
327
+ }
328
+
329
+ iterator_range<iterator> getIndices () const {
330
+ return make_range (begin (), end ());
331
+ }
332
+
333
+ unsigned getNumIndices () const {
334
+ return (unsigned )std::distance (begin (), end ());
335
+ }
336
+
337
+ bool contains (unsigned index) const {
338
+ unsigned bitWordIndex, offset;
339
+ std::tie (bitWordIndex, offset) = getBitWordIndexAndOffset (index);
340
+ return getBitWord (bitWordIndex) & (1 << offset);
341
+ }
342
+
343
+ bool isEmpty () const {
344
+ return llvm::all_of (getBitWords (), [](BitWord bw) { return !(bool )bw; });
345
+ }
346
+
347
+ bool equals (AutoDiffIndexSubset *other) const {
348
+ return capacity == other->getCapacity () &&
349
+ getBitWords ().equals (other->getBitWords ());
350
+ }
351
+
352
+ bool isSubsetOf (AutoDiffIndexSubset *other) const ;
353
+ bool isSupersetOf (AutoDiffIndexSubset *other) const ;
354
+
355
+ AutoDiffIndexSubset *adding (unsigned index, ASTContext &ctx) const ;
356
+ AutoDiffIndexSubset *extendingCapacity (ASTContext &ctx,
357
+ unsigned newCapacity) const ;
358
+
359
+ void Profile (llvm::FoldingSetNodeID &id) const {
360
+ id.AddInteger (capacity);
361
+ for (auto index : getIndices ())
362
+ id.AddInteger (index);
363
+ }
364
+
365
+ void print (llvm::raw_ostream &s = llvm::outs()) const {
366
+ s << ' {' ;
367
+ interleave (range (capacity), [this , &s](unsigned i) { s << contains (i); },
368
+ [&s] { s << " , " ; });
369
+ s << ' }' ;
370
+ }
371
+
372
+ void dump (llvm::raw_ostream &s = llvm::errs()) const {
373
+ s << " (autodiff_index_subset capacity=" << capacity << " indices=(" ;
374
+ interleave (getIndices (), [&s](unsigned i) { s << i; },
375
+ [&s] { s << " , " ; });
376
+ s << " ))" ;
377
+ }
378
+
379
+ int findNext (int startIndex) const ;
380
+ int findFirst () const { return findNext (-1 ); }
381
+ int findPrevious (int endIndex) const ;
382
+ int findLast () const { return findPrevious (capacity); }
383
+
384
+ class iterator {
385
+ public:
386
+ typedef unsigned value_type;
387
+ typedef unsigned difference_type;
388
+ typedef unsigned * pointer;
389
+ typedef unsigned & reference;
390
+ typedef std::forward_iterator_tag iterator_category;
391
+
392
+ private:
393
+ const AutoDiffIndexSubset *parent;
394
+ int current = 0 ;
395
+
396
+ void advance () {
397
+ assert (current != -1 && " Trying to advance past end." );
398
+ current = parent->findNext (current);
399
+ }
400
+
401
+ public:
402
+ iterator (const AutoDiffIndexSubset *parent, int current)
403
+ : parent(parent), current(current) {}
404
+ explicit iterator (const AutoDiffIndexSubset *parent)
405
+ : iterator(parent, parent->findFirst ()) {}
406
+ iterator (const iterator &) = default;
407
+
408
+ iterator operator ++(int ) {
409
+ auto prev = *this ;
410
+ advance ();
411
+ return prev;
412
+ }
413
+
414
+ iterator &operator ++() {
415
+ advance ();
416
+ return *this ;
417
+ }
418
+
419
+ unsigned operator *() const { return current; }
420
+
421
+ bool operator ==(const iterator &other) const {
422
+ assert (parent == other.parent &&
423
+ " Comparing iterators from different AutoDiffIndexSubsets" );
424
+ return current == other.current ;
425
+ }
426
+
427
+ bool operator !=(const iterator &other) const {
428
+ assert (parent == other.parent &&
429
+ " Comparing iterators from different AutoDiffIndexSubsets" );
430
+ return current != other.current ;
431
+ }
432
+ };
433
+ };
434
+
222
435
// / SIL-level automatic differentiation indices. Consists of a source index,
223
436
// / i.e. index of the dependent result to differentiate from, and parameter
224
437
// / indices, i.e. index of independent parameters to differentiate with
@@ -242,38 +455,33 @@ struct SILAutoDiffIndices {
242
455
// / Function type: (A, B) -> (C, D) -> R
243
456
// / Bits: [C][D][A][B]
244
457
// /
245
- llvm::SmallBitVector parameters;
458
+ AutoDiffIndexSubset * parameters;
246
459
247
460
// / Creates a set of AD indices from the given source index and a bit vector
248
461
// / representing parameter indices.
249
462
/* implicit*/ SILAutoDiffIndices(unsigned source,
250
- llvm::SmallBitVector parameters)
463
+ AutoDiffIndexSubset * parameters)
251
464
: source(source), parameters(parameters) {}
252
465
253
- // / Creates a set of AD indices from the given source index and an array of
254
- // / parameter indices. Elements in `parameters` must be ascending integers.
255
- /* implicit*/ SILAutoDiffIndices(unsigned source,
256
- ArrayRef<unsigned > parameters);
257
-
258
466
bool operator ==(const SILAutoDiffIndices &other) const ;
259
467
260
468
// / Queries whether the function's parameter with index `parameterIndex` is
261
469
// / one of the parameters to differentiate with respect to.
262
470
bool isWrtParameter (unsigned parameterIndex) const {
263
- return parameterIndex < parameters. size () &&
264
- parameters. test (parameterIndex);
471
+ return parameterIndex < parameters-> getCapacity () &&
472
+ parameters-> contains (parameterIndex);
265
473
}
266
474
267
475
void print (llvm::raw_ostream &s = llvm::outs()) const {
268
476
s << " (source=" << source << " parameters=(" ;
269
- interleave (parameters. set_bits (),
477
+ interleave (parameters-> getIndices (),
270
478
[&s](unsigned p) { s << p; }, [&s]{ s << ' ' ; });
271
479
s << " ))" ;
272
480
}
273
481
274
482
std::string mangle () const {
275
483
std::string result = " src_" + llvm::utostr (source) + " _wrt_" ;
276
- interleave (parameters. set_bits (),
484
+ interleave (parameters-> getIndices (),
277
485
[&](unsigned idx) { result += llvm::utostr (idx); },
278
486
[&] { result += ' _' ; });
279
487
return result;
@@ -449,19 +657,18 @@ template<typename T> struct DenseMapInfo;
449
657
450
658
template <> struct DenseMapInfo <SILAutoDiffIndices> {
451
659
static SILAutoDiffIndices getEmptyKey () {
452
- return { DenseMapInfo<unsigned >::getEmptyKey (), SmallBitVector () };
660
+ return { DenseMapInfo<unsigned >::getEmptyKey (), nullptr };
453
661
}
454
662
455
663
static SILAutoDiffIndices getTombstoneKey () {
456
- return { DenseMapInfo<unsigned >::getTombstoneKey (),
457
- SmallBitVector (sizeof (intptr_t ), true ) };
664
+ return { DenseMapInfo<unsigned >::getTombstoneKey (), nullptr };
458
665
}
459
666
460
667
static unsigned getHashValue (const SILAutoDiffIndices &Val) {
461
- auto params = Val.parameters .set_bits ();
462
668
unsigned combinedHash =
463
669
hash_combine (~1U , DenseMapInfo<unsigned >::getHashValue (Val.source ),
464
- hash_combine_range (params.begin (), params.end ()));
670
+ hash_combine_range (Val.parameters ->begin (),
671
+ Val.parameters ->end ()));
465
672
return combinedHash;
466
673
}
467
674
0 commit comments