@@ -63,71 +63,18 @@ using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
 //
-// Implementation details for public functions, which don't have a good
-// place to live in the C++ library this file is wrapping.
+// Utilities for manipulating `StridedMemRefType`.
 //
 //===----------------------------------------------------------------------===//
 
 namespace {
 
-/// Wrapper class to avoid memory leakage issues. The `SparseTensorCOO<V>`
-/// class provides a standard C++ iterator interface, where the iterator
-/// is implemented as per `std::vector`'s iterator. However, for MLIR's
-/// usage we need to have an iterator which also holds onto the underlying
-/// `SparseTensorCOO<V>` so that it can be freed whenever the iterator
-/// is freed.
-//
-// We name this `SparseTensorIterator` rather than `SparseTensorCOOIterator`
-// for future-proofing, since the use of `SparseTensorCOO` is an
-// implementation detail that we eventually want to change (e.g., to
-// use `SparseTensorEnumerator` directly, rather than constructing the
-// intermediate `SparseTensorCOO` at all).
-template <typename V>
-class SparseTensorIterator final {
-public:
-  /// This ctor requires `coo` to be a non-null pointer to a dynamically
-  /// allocated object, and takes ownership of that object. Therefore,
-  /// callers must not free the underlying COO object, since the iterator's
-  /// dtor will do so.
-  explicit SparseTensorIterator(const SparseTensorCOO<V> *coo)
-      : coo(coo), it(coo->begin()), end(coo->end()) {}
-
-  ~SparseTensorIterator() { delete coo; }
-
-  // Disable copy-ctor and copy-assignment, to prevent double-free.
-  SparseTensorIterator(const SparseTensorIterator<V> &) = delete;
-  SparseTensorIterator<V> &operator=(const SparseTensorIterator<V> &) = delete;
-
-  /// Gets the next element. If there are no remaining elements, then
-  /// returns nullptr.
-  const Element<V> *getNext() { return it < end ? &*it++ : nullptr; }
-
-private:
-  const SparseTensorCOO<V> *const coo; // Owning pointer.
-  typename SparseTensorCOO<V>::const_iterator it;
-  const typename SparseTensorCOO<V>::const_iterator end;
-};
-
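For context, the class removed above was only ever used in an owning, drain-once pattern. A hedged sketch of such a caller follows; it is hypothetical and not part of the patch, `consumeCoordinate`/`consumeValue` are stand-ins, and `Element<V>`'s `coords`/`value` members are used exactly as in the `getNext` wrapper removed further below.

// Illustrative only, not part of the patch: take ownership of a freshly
// built COO and stream its elements once.
template <typename V>
static void consumeAll(SparseTensorCOO<V> *coo, uint64_t rank) {
  SparseTensorIterator<V> iter(coo); // takes ownership; dtor frees `coo`
  while (const Element<V> *elem = iter.getNext()) {
    for (uint64_t d = 0; d < rank; ++d)
      consumeCoordinate(d, elem->coords[d]); // stand-in consumer
    consumeValue(elem->value);               // stand-in consumer
  }
} // `coo` is freed here by ~SparseTensorIterator()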
-//===----------------------------------------------------------------------===//
-//
-// Utilities for manipulating `StridedMemRefType`.
-//
-//===----------------------------------------------------------------------===//
-
-// We shouldn't need to use `detail::safelyEQ` here since the `1` is a literal.
 #define ASSERT_NO_STRIDE(MEMREF) \
   do { \
     assert((MEMREF) && "Memref is nullptr"); \
     assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \
   } while (false)
 
-// All our functions use `uint64_t` for ranks, but `StridedMemRefType::sizes`
-// uses `int64_t` on some platforms. So we explicitly cast this lookup to
-// ensure we get a consistent type, and we use `checkOverflowCast` rather
-// than `static_cast` just to be extremely sure that the casting can't
-// go awry. (The cast should always be safe since (1) sizes should never
-// be negative, and (2) the maximum `int64_t` is smaller than the maximum
-// `uint64_t`. But it's better to be safe than sorry.)
 #define MEMREF_GET_USIZE(MEMREF) \
   detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])
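Taken together with `MEMREF_GET_PAYLOAD` (kept unchanged in the next hunk), these macros form the usual prologue of the `_mlir_ciface_` wrappers in this file. A minimal sketch, assuming a made-up entry point `_mlir_ciface_sumValuesF64` that is not in the patch:

// Hypothetical example only: validate a rank-1 memref argument, then read
// its extent and payload the same way the wrappers below do.
double _mlir_ciface_sumValuesF64(StridedMemRefType<double, 1> *vref) {
  ASSERT_NO_STRIDE(vref);                       // non-null and unit stride
  const uint64_t size = MEMREF_GET_USIZE(vref); // sizes[0] as uint64_t
  double *values = MEMREF_GET_PAYLOAD(vref);    // data + offset
  double sum = 0.0;
  for (uint64_t i = 0; i < size; ++i)
    sum += values[i];
  return sum;
}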
@@ -137,22 +84,13 @@ class SparseTensorIterator final {
 
 #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
 
-/// Initializes the memref with the provided size and data pointer.  This
+/// Initializes the memref with the provided size and data pointer. This
 /// is designed for functions which want to "return" a memref that aliases
 /// into memory owned by some other object (e.g., `SparseTensorStorage`),
 /// without doing any actual copying. (The "return" is in scarequotes
 /// because the `_mlir_ciface_` calling convention migrates any returned
 /// memrefs into an out-parameter passed before all the other function
 /// parameters.)
-///
-/// We make this a function rather than a macro mainly for type safety
-/// reasons. This function does not modify the data pointer, but it
-/// cannot be marked `const` because it is stored into the (necessarily)
-/// non-`const` memref. This function is templated over the `DataSizeT`
-/// to work around signedness warnings due to many data types having
-/// varying signedness across different platforms. The templating allows
-/// this function to ensure that it does the right thing and never
-/// introduces errors due to implicit conversions.
 template <typename DataSizeT, typename T>
 static inline void aliasIntoMemref(DataSizeT size, T *data,
                                    StridedMemRefType<T, 1> &ref) {
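  // NOTE: the function body lies outside this hunk. A plausible sketch (not
  // the actual implementation), assuming the standard StridedMemRefType<T, 1>
  // fields (basePtr, data, offset, sizes, strides), boils down to:
  //
  //   ref.basePtr = ref.data = data; // share the caller-owned buffer, no copy
  //   ref.offset = 0;
  //   ref.sizes[0] = static_cast<int64_t>(size); // rank-1 extent
  //   ref.strides[0] = 1;                        // unit stride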
@@ -200,20 +138,11 @@ extern "C" {
         dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
         dimRank, tensor); \
   } \
-  case Action::kFuture: { \
-    break; \
-  } \
   case Action::kToCOO: { \
     assert(ptr && "Received nullptr for SparseTensorStorage object"); \
     auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
     return tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim); \
   } \
-  case Action::kToIterator: { \
-    assert(ptr && "Received nullptr for SparseTensorStorage object"); \
-    auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
-    auto *coo = tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim); \
-    return new SparseTensorIterator<V>(coo); \
-  } \
   case Action::kPack: { \
     assert(ptr && "Received nullptr for SparseTensorStorage object"); \
     intptr_t *buffers = static_cast<intptr_t *>(ptr); \
@@ -372,7 +301,6 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
 
   // Unsupported case (add above if needed).
-  // TODO: better pretty-printing of enum values!
   MLIR_SPARSETENSOR_FATAL(
       "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
       static_cast<int>(posTp), static_cast<int>(crdTp),
@@ -428,29 +356,6 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
428
356
MLIR_SPARSETENSOR_FOREVERY_V (IMPL_FORWARDINGINSERT)
429
357
#undef IMPL_FORWARDINGINSERT
430
358
431
- // NOTE: the `cref` argument uses the same coordinate-space as the `iter`
432
- // (which can be either dim- or lvl-coords, depending on context).
433
- #define IMPL_GETNEXT (VNAME, V ) \
434
- bool _mlir_ciface_getNext##VNAME(void *iter, \
435
- StridedMemRefType<index_type, 1 > *cref, \
436
- StridedMemRefType<V, 0 > *vref) { \
437
- assert (iter &&vref); \
438
- ASSERT_NO_STRIDE (cref); \
439
- index_type *coords = MEMREF_GET_PAYLOAD (cref); \
440
- V *value = MEMREF_GET_PAYLOAD (vref); \
441
- const uint64_t rank = MEMREF_GET_USIZE (cref); \
442
- const Element<V> *elem = \
443
- static_cast <SparseTensorIterator<V> *>(iter)->getNext (); \
444
- if (elem == nullptr ) \
445
- return false ; \
446
- for (uint64_t d = 0 ; d < rank; d++) \
447
- coords[d] = elem->coords [d]; \
448
- *value = elem->value ; \
449
- return true ; \
450
- }
451
- MLIR_SPARSETENSOR_FOREVERY_V (IMPL_GETNEXT)
452
- #undef IMPL_GETNEXT
453
-
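Before this patch, a host-side caller (or compiler-generated code) drove the removed `getNext` entry points roughly as in the hedged sketch below; `drainF64` is a made-up helper, and `iter` is the handle previously produced by the removed `kToIterator` action.

// Hypothetical caller of the removed F64 instantiation, showing how the
// cref/vref memrefs are typically set up around the drain loop.
static void drainF64(void *iter, uint64_t rank) {
  index_type coords[16]; // assumes rank <= 16, for the sketch only
  double value = 0.0;
  StridedMemRefType<index_type, 1> cref;
  StridedMemRefType<double, 0> vref;
  aliasIntoMemref(rank, coords, cref); // cref aliases `coords`, size = rank
  vref.basePtr = vref.data = &value;   // rank-0 memref wrapping `value`
  vref.offset = 0;
  while (_mlir_ciface_getNextF64(iter, &cref, &vref)) {
    // ... consume coords[0 .. rank) and value ...
  }
}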
 #define IMPL_LEXINSERT(VNAME, V) \
   void _mlir_ciface_lexInsert##VNAME( \
       void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \
@@ -636,7 +541,6 @@ void *_mlir_ciface_newSparseTensorFromReader(
   CASE_SECSAME(kU64, kC32, uint64_t, complex32);
 
   // Unsupported case (add above if needed).
-  // TODO: better pretty-printing of enum values!
   MLIR_SPARSETENSOR_FATAL(
       "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
       static_cast<int>(posTp), static_cast<int>(crdTp),
@@ -701,7 +605,7 @@ void endLexInsert(void *tensor) {
 
 #define IMPL_OUTSPARSETENSOR(VNAME, V) \
   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) { \
-    assert(coo && "Got nullptr for COO object"); \
+    assert(coo); \
     auto &coo_ = *static_cast<SparseTensorCOO<V> *>(coo); \
     if (sort) \
       coo_.sort(); \
@@ -721,13 +625,6 @@ void delSparseTensor(void *tensor) {
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELCOO)
 #undef IMPL_DELCOO
 
-#define IMPL_DELITER(VNAME, V) \
-  void delSparseTensorIterator##VNAME(void *iter) { \
-    delete static_cast<SparseTensorIterator<V> *>(iter); \
-  }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELITER)
-#undef IMPL_DELITER
-
 char *getTensorFilename(index_type id) {
   constexpr size_t BUF_SIZE = 80;
   char var[BUF_SIZE];