Skip to content

Commit a9a19f5

Browse files
committed
[mlir][sparse] Adding x-macros for OverheadType
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D126026
1 parent 1dfd8e9 commit a9a19f5

File tree

1 file changed

+38
-32
lines changed

1 file changed

+38
-32
lines changed

mlir/lib/ExecutionEngine/SparseTensorUtils.cpp

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,26 @@ struct SparseTensorCOO final {
269269
DO(C64, complex64) \
270270
DO(C32, complex32)
271271

272+
// This x-macro calls its argument on every overhead type which has
273+
// fixed-width. It excludes `index_type` because that type is often
274+
// handled specially (e.g., by translating it into the architecture-dependent
275+
// equivalent fixed-width overhead type).
276+
#define FOREVERY_FIXED_O(DO) \
277+
DO(64, uint64_t) \
278+
DO(32, uint32_t) \
279+
DO(16, uint16_t) \
280+
DO(8, uint8_t)
281+
282+
// This x-macro calls its argument on every overhead type, including
283+
// `index_type`. Our naming convention uses an empty suffix for
284+
// `index_type`, so the missing first argument when we call `DO`
285+
// gets resolved to the empty token which can then be concatenated
286+
// as intended. (This behavior is standard per C99 6.10.3/4 and
287+
// C++11 N3290 16.3/4; whereas in C++03 16.3/10 it was undefined behavior.)
288+
#define FOREVERY_O(DO) \
289+
FOREVERY_FIXED_O(DO) \
290+
DO(, index_type)
291+
272292
// Forward.
273293
template <typename V>
274294
class SparseTensorEnumeratorBase;
@@ -347,30 +367,18 @@ class SparseTensorStorageBase {
347367
#undef DECL_NEWENUMERATOR
348368

349369
/// Overhead storage.
350-
virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
351-
FATAL_PIV("p64");
352-
}
353-
virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
354-
FATAL_PIV("p32");
355-
}
356-
virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
357-
FATAL_PIV("p16");
358-
}
359-
virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
360-
FATAL_PIV("p8");
361-
}
362-
virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
363-
FATAL_PIV("i64");
364-
}
365-
virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
366-
FATAL_PIV("i32");
367-
}
368-
virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
369-
FATAL_PIV("i16");
370+
#define DECL_GETPOINTERS(PNAME, P) \
371+
virtual void getPointers(std::vector<P> **, uint64_t) { \
372+
FATAL_PIV("getPointers" #PNAME); \
370373
}
371-
virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
372-
FATAL_PIV("i8");
374+
FOREVERY_FIXED_O(DECL_GETPOINTERS)
375+
#undef DECL_GETPOINTERS
376+
#define DECL_GETINDICES(INAME, I) \
377+
virtual void getIndices(std::vector<I> **, uint64_t) { \
378+
FATAL_PIV("getIndices" #INAME); \
373379
}
380+
FOREVERY_FIXED_O(DECL_GETINDICES)
381+
#undef DECL_GETINDICES
374382

375383
/// Primary storage.
376384
#define DECL_GETVALUES(VNAME, V) \
@@ -1576,18 +1584,16 @@ FOREVERY_V(IMPL_SPARSEVALUES)
15761584
ref->strides[0] = 1; \
15771585
}
15781586
/// Methods that provide direct access to pointers.
1579-
IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1580-
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1581-
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1582-
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1583-
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1587+
#define IMPL_SPARSEPOINTERS(PNAME, P) \
1588+
IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1589+
FOREVERY_O(IMPL_SPARSEPOINTERS)
1590+
#undef IMPL_SPARSEPOINTERS
15841591

15851592
/// Methods that provide direct access to indices.
1586-
IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1587-
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1588-
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1589-
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1590-
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1593+
#define IMPL_SPARSEINDICES(INAME, I) \
1594+
IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1595+
FOREVERY_O(IMPL_SPARSEINDICES)
1596+
#undef IMPL_SPARSEINDICES
15911597
#undef IMPL_GETOVERHEAD
15921598

15931599
/// Helper to add value to coordinate scheme, one per value type.

0 commit comments

Comments
 (0)