@@ -269,6 +269,26 @@ struct SparseTensorCOO final {
269
269
DO (C64, complex64) \
270
270
DO (C32, complex32)
271
271
272
+ // This x-macro calls its argument on every overhead type which has
273
+ // fixed-width. It excludes `index_type` because that type is often
274
+ // handled specially (e.g., by translating it into the architecture-dependent
275
+ // equivalent fixed-width overhead type).
276
+ #define FOREVERY_FIXED_O (DO ) \
277
+ DO (64 , uint64_t ) \
278
+ DO (32 , uint32_t ) \
279
+ DO (16 , uint16_t ) \
280
+ DO (8 , uint8_t )
281
+
282
+ // This x-macro calls its argument on every overhead type, including
283
+ // `index_type`. Our naming convention uses an empty suffix for
284
+ // `index_type`, so the missing first argument when we call `DO`
285
+ // gets resolved to the empty token which can then be concatenated
286
+ // as intended. (This behavior is standard per C99 6.10.3/4 and
287
+ // C++11 N3290 16.3/4; whereas in C++03 16.3/10 it was undefined behavior.)
288
+ #define FOREVERY_O (DO ) \
289
+ FOREVERY_FIXED_O (DO) \
290
+ DO (, index_type)
291
+
272
292
// Forward.
273
293
template <typename V>
274
294
class SparseTensorEnumeratorBase ;
@@ -347,30 +367,18 @@ class SparseTensorStorageBase {
347
367
#undef DECL_NEWENUMERATOR
348
368
349
369
// / Overhead storage.
350
- virtual void getPointers (std::vector<uint64_t > **, uint64_t ) {
351
- FATAL_PIV (" p64" );
352
- }
353
- virtual void getPointers (std::vector<uint32_t > **, uint64_t ) {
354
- FATAL_PIV (" p32" );
355
- }
356
- virtual void getPointers (std::vector<uint16_t > **, uint64_t ) {
357
- FATAL_PIV (" p16" );
358
- }
359
- virtual void getPointers (std::vector<uint8_t > **, uint64_t ) {
360
- FATAL_PIV (" p8" );
361
- }
362
- virtual void getIndices (std::vector<uint64_t > **, uint64_t ) {
363
- FATAL_PIV (" i64" );
364
- }
365
- virtual void getIndices (std::vector<uint32_t > **, uint64_t ) {
366
- FATAL_PIV (" i32" );
367
- }
368
- virtual void getIndices (std::vector<uint16_t > **, uint64_t ) {
369
- FATAL_PIV (" i16" );
370
+ #define DECL_GETPOINTERS (PNAME, P ) \
371
+ virtual void getPointers (std::vector<P> **, uint64_t ) { \
372
+ FATAL_PIV (" getPointers" #PNAME); \
370
373
}
371
- virtual void getIndices (std::vector<uint8_t > **, uint64_t ) {
372
- FATAL_PIV (" i8" );
374
+ FOREVERY_FIXED_O (DECL_GETPOINTERS)
375
+ #undef DECL_GETPOINTERS
376
+ #define DECL_GETINDICES (INAME, I ) \
377
+ virtual void getIndices (std::vector<I> **, uint64_t ) { \
378
+ FATAL_PIV (" getIndices" #INAME); \
373
379
}
380
+ FOREVERY_FIXED_O (DECL_GETINDICES)
381
+ #undef DECL_GETINDICES
374
382
375
383
// / Primary storage.
376
384
#define DECL_GETVALUES (VNAME, V ) \
@@ -1576,18 +1584,16 @@ FOREVERY_V(IMPL_SPARSEVALUES)
1576
1584
ref->strides [0 ] = 1 ; \
1577
1585
}
1578
1586
// / Methods that provide direct access to pointers.
1579
- IMPL_GETOVERHEAD (sparsePointers, index_type, getPointers)
1580
- IMPL_GETOVERHEAD(sparsePointers64, uint64_t , getPointers)
1581
- IMPL_GETOVERHEAD(sparsePointers32, uint32_t , getPointers)
1582
- IMPL_GETOVERHEAD(sparsePointers16, uint16_t , getPointers)
1583
- IMPL_GETOVERHEAD(sparsePointers8, uint8_t , getPointers)
1587
+ #define IMPL_SPARSEPOINTERS (PNAME, P ) \
1588
+ IMPL_GETOVERHEAD (sparsePointers##PNAME, P, getPointers)
1589
+ FOREVERY_O(IMPL_SPARSEPOINTERS)
1590
+ #undef IMPL_SPARSEPOINTERS
1584
1591
1585
1592
// / Methods that provide direct access to indices.
1586
- IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1587
- IMPL_GETOVERHEAD(sparseIndices64, uint64_t , getIndices)
1588
- IMPL_GETOVERHEAD(sparseIndices32, uint32_t , getIndices)
1589
- IMPL_GETOVERHEAD(sparseIndices16, uint16_t , getIndices)
1590
- IMPL_GETOVERHEAD(sparseIndices8, uint8_t , getIndices)
1593
+ #define IMPL_SPARSEINDICES (INAME, I ) \
1594
+ IMPL_GETOVERHEAD (sparseIndices##INAME, I, getIndices)
1595
+ FOREVERY_O(IMPL_SPARSEINDICES)
1596
+ #undef IMPL_SPARSEINDICES
1591
1597
#undef IMPL_GETOVERHEAD
1592
1598
1593
1599
// / Helper to add value to coordinate scheme, one per value type.
0 commit comments