@@ -223,22 +223,31 @@ class AutoDiffParameterIndicesBuilder {
223
223
};
224
224
225
225
class AutoDiffIndexSubset : public llvm ::FoldingSetNode {
226
- private:
227
- using BitWord = uint64_t ;
226
+ public:
227
+ typedef uint64_t BitWord;
228
+
229
+ static constexpr unsigned bitWordSize = sizeof (BitWord);
230
+ static constexpr unsigned numBitsPerBitWord = bitWordSize * 8 ;
231
+
232
+ static std::pair<unsigned , unsigned >
233
+ getBitWordIndexAndOffset (unsigned index) {
234
+ auto bitWordIndex = index / numBitsPerBitWord;
235
+ auto bitWordOffset = index % numBitsPerBitWord;
236
+ return {bitWordIndex, bitWordOffset};
237
+ }
228
238
239
+ static unsigned getNumBitWordsNeededForCapacity (unsigned capacity) {
240
+ if (capacity == 0 ) return 0 ;
241
+ return capacity / numBitsPerBitWord + 1 ;
242
+ }
243
+
244
+ private:
229
245
// / The total capacity of the index subset, which is `1` less than the largest
230
246
// / index.
231
247
unsigned capacity;
232
248
// / The number of bit words in the index subset. in the index subset.
233
249
unsigned numBitWords;
234
250
235
- static std::pair<unsigned , unsigned > getBitWordIndexAndOffset (unsigned index);
236
- static unsigned getNumBitWordsNeededForCapacity (unsigned capacity);
237
-
238
- unsigned getNumBitWords () const {
239
- return numBitWords;
240
- }
241
-
242
251
BitWord *getBitWordsData () {
243
252
return reinterpret_cast <BitWord *>(this + 1 );
244
253
}
@@ -263,62 +272,159 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
263
272
return {const_cast <BitWord *>(getBitWordsData ()), getNumBitWords ()};
264
273
}
265
274
266
- explicit AutoDiffIndexSubset (unsigned capacity, unsigned numBitWords,
267
- ArrayRef<unsigned > indices);
275
+ explicit AutoDiffIndexSubset (unsigned capacity, ArrayRef<unsigned > indices)
276
+ : capacity(capacity),
277
+ numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
278
+ std::uninitialized_fill_n (getBitWordsData (), numBitWords, 0 );
279
+ for (auto i : indices) {
280
+ unsigned bitWordIndex, offset;
281
+ std::tie (bitWordIndex, offset) = getBitWordIndexAndOffset (i);
282
+ getBitWord (bitWordIndex) |= (1 << offset);
283
+ }
284
+ }
268
285
269
286
public:
270
287
AutoDiffIndexSubset () = delete ;
271
288
AutoDiffIndexSubset (const AutoDiffIndexSubset &) = delete ;
272
289
AutoDiffIndexSubset &operator =(const AutoDiffIndexSubset &) = delete ;
273
290
274
- static AutoDiffIndexSubset *get (ASTContext &ctx, unsigned capacity,
275
- bool includeAll = false );
276
- static AutoDiffIndexSubset *get (ASTContext &ctx, unsigned capacity,
277
- IntRange<> range);
278
- static AutoDiffIndexSubset *get (ASTContext &ctx, unsigned capacity,
291
+ // Defined in ASTContext.h.
292
+ static AutoDiffIndexSubset *get (ASTContext &ctx,
293
+ unsigned capacity,
279
294
ArrayRef<unsigned > indices);
280
295
296
+ static AutoDiffIndexSubset *getDefault (ASTContext &ctx,
297
+ unsigned capacity,
298
+ bool includeAll = false ) {
299
+ if (includeAll)
300
+ return getFromRange (ctx, capacity, IntRange<>(capacity));
301
+ return get (ctx, capacity, {});
302
+ }
303
+
304
+ static AutoDiffIndexSubset *getFromRange (ASTContext &ctx,
305
+ unsigned capacity,
306
+ IntRange<> range) {
307
+ return get (ctx, capacity,
308
+ SmallVector<unsigned , 8 >(range.begin (), range.end ()));
309
+ }
310
+
311
+ unsigned getNumBitWords () const {
312
+ return numBitWords;
313
+ }
314
+
281
315
unsigned getCapacity () const {
282
316
return capacity;
283
317
}
284
318
285
319
class iterator ;
286
320
287
- iterator begin () const ;
288
- iterator end () const ;
289
- iterator_range<iterator> getIndices () const ;
290
- unsigned getNumIndices () const ;
321
+ iterator begin () const {
322
+ return iterator (this );
323
+ }
324
+
325
+ iterator end () const {
326
+ return iterator (this , (int )capacity);
327
+ }
328
+
329
+ iterator_range<iterator> getIndices () const {
330
+ return make_range (begin (), end ());
331
+ }
332
+
333
+ unsigned getNumIndices () const {
334
+ return (unsigned )std::distance (begin (), end ());
335
+ }
291
336
292
337
bool contains (unsigned index) const {
293
338
unsigned bitWordIndex, offset;
294
339
std::tie (bitWordIndex, offset) = getBitWordIndexAndOffset (index);
295
340
return getBitWord (bitWordIndex) & (1 << offset);
296
341
}
297
342
298
- bool isEmpty () const ;
299
- bool equals (const AutoDiffIndexSubset *other) const ;
300
- bool isSubsetOf (const AutoDiffIndexSubset *other) const ;
301
- bool isSupersetOf (const AutoDiffIndexSubset *other) const ;
343
+ bool isEmpty () const {
344
+ return llvm::all_of (getBitWords (), [](BitWord bw) { return !(bool )bw; });
345
+ }
346
+
347
+ bool equals (const AutoDiffIndexSubset *other) const {
348
+ return capacity == other->getCapacity () &&
349
+ getBitWords ().equals (other->getBitWords ());
350
+ }
351
+
352
+ bool isSubsetOf (const AutoDiffIndexSubset *other) const {
353
+ assert (capacity == other->capacity );
354
+ for (auto index : range (numBitWords))
355
+ if (getBitWord (index) & ~other->getBitWord (index))
356
+ return false ;
357
+ return true ;
358
+ }
359
+
360
+ bool isSupersetOf (const AutoDiffIndexSubset *other) const {
361
+ assert (capacity == other->capacity );
362
+ for (auto index : range (numBitWords))
363
+ if (~getBitWord (index) & other->getBitWord (index))
364
+ return false ;
365
+ return true ;
366
+ }
302
367
303
- AutoDiffIndexSubset *adding (unsigned index, ASTContext &ctx) const ;
304
- AutoDiffIndexSubset *extendingCapacity (ASTContext &ctx,
305
- unsigned newCapacity) const ;
368
+ AutoDiffIndexSubset *adding (
369
+ unsigned index, ASTContext &ctx) const {
370
+ assert (index < getCapacity ());
371
+ SmallVector<unsigned , 8 > newIndices;
372
+ newIndices.reserve (capacity + 1 );
373
+ bool inserted = false ;
374
+ for (auto curIndex : getIndices ()) {
375
+ if (inserted && curIndex > index) {
376
+ newIndices.push_back (index);
377
+ inserted = false ;
378
+ }
379
+ newIndices.push_back (curIndex);
380
+ }
381
+ return get (ctx, capacity, newIndices);
382
+ }
306
383
307
- void Profile (llvm::FoldingSetNodeID &id) const ;
384
+ AutoDiffIndexSubset *extendingCapacity (
385
+ ASTContext &ctx, unsigned newCapacity) const {
386
+ assert (newCapacity >= capacity);
387
+ if (newCapacity == capacity)
388
+ return const_cast <AutoDiffIndexSubset *>(this );
389
+ SmallVector<unsigned , 8 > indices;
390
+ for (auto index : getIndices ())
391
+ indices.push_back (index);
392
+ return AutoDiffIndexSubset::get (ctx, newCapacity, indices);
393
+ }
394
+
395
+ void Profile (llvm::FoldingSetNodeID &id) const {
396
+ id.AddInteger (capacity);
397
+ for (auto index : getIndices ())
398
+ id.AddInteger (index);
399
+ }
400
+
401
+ void print (llvm::raw_ostream &s = llvm::outs()) const {
402
+ s << ' {' ;
403
+ interleave (range (capacity), [this , &s](unsigned i) { s << contains (i); },
404
+ [&s] { s << " , " ; });
405
+ s << ' }' ;
406
+ }
407
+
408
+ void dump (llvm::raw_ostream &s = llvm::errs()) const {
409
+ s << " (autodiff_index_subset capacity=" << capacity << " indices=(" ;
410
+ interleave (getIndices (), [&s](unsigned i) { s << i; },
411
+ [&s] { s << " , " ; });
412
+ s << " ))" ;
413
+ }
308
414
309
- private:
310
415
int findNext (int startIndex) const ;
311
416
int findFirst () const { return findNext (-1 ); }
312
417
int findPrevious (int endIndex) const ;
313
418
int findLast () const { return findPrevious (capacity); }
314
419
315
- public:
316
420
class iterator {
317
- typedef unsigned value_type;
318
- typedef int difference_type;
319
- typedef unsigned * pointer;
320
- typedef unsigned & reference;
321
- typedef std::forward_iterator_tag iterator_category;
421
+ public:
422
+ typedef unsigned value_type;
423
+ typedef unsigned difference_type;
424
+ typedef unsigned * pointer;
425
+ typedef unsigned & reference;
426
+ typedef std::forward_iterator_tag iterator_category;
427
+
322
428
private:
323
429
const AutoDiffIndexSubset *parent;
324
430
int current = 0 ;
@@ -349,42 +455,19 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
349
455
unsigned operator *() const { return current; }
350
456
351
457
bool operator ==(const iterator &other) const {
352
- assert (& parent == & other.parent &&
458
+ assert (parent == other.parent &&
353
459
" Comparing iterators from different AutoDiffIndexSubsets" );
354
460
return current == other.current ;
355
461
}
356
462
357
463
bool operator !=(const iterator &other) const {
358
- assert (& parent == & other.parent &&
464
+ assert (parent == other.parent &&
359
465
" Comparing iterators from different AutoDiffIndexSubsets" );
360
466
return current != other.current ;
361
467
}
362
468
};
363
469
};
364
470
365
- class AutoDiffFunctionParameterSubset {
366
- private:
367
- AutoDiffIndexSubset *indexSubset;
368
- bool curried;
369
-
370
- public:
371
- explicit AutoDiffFunctionParameterSubset (
372
- AutoDiffIndexSubset *indexSubset, bool isCurried)
373
- : indexSubset(indexSubset), curried(isCurried) {}
374
-
375
- explicit AutoDiffFunctionParameterSubset (
376
- ASTContext &ctx, AutoDiffIndexSubset *parameterSubset,
377
- Optional<bool > isSelfIncluded);
378
-
379
- AutoDiffIndexSubset *getIndexSubset () const {
380
- return indexSubset;
381
- }
382
-
383
- bool isCurried () const {
384
- return curried;
385
- }
386
- };
387
-
388
471
// / SIL-level automatic differentiation indices. Consists of a source index,
389
472
// / i.e. index of the dependent result to differentiate from, and parameter
390
473
// / indices, i.e. index of independent parameters to differentiate with
0 commit comments