Commit 9bd5bfc

[mlir][sparse] remove unused sparse tensor iterator (#68951)
1 parent 28b27c1 commit 9bd5bfc

3 files changed: 7 additions, 127 deletions

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 0 additions & 3 deletions
@@ -146,11 +146,8 @@ enum class Action : uint32_t {
   kEmptyForward = 1,
   kFromCOO = 2,
   kSparseToSparse = 3,
-  kFuture = 4, // not used
   kToCOO = 5,
-  kToIterator = 6,
   kPack = 7,
-  // Sort an unordered COO in place.
   kSortCOOInPlace = 8,
 };

mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

Lines changed: 4 additions & 18 deletions
@@ -39,18 +39,18 @@ extern "C" {
 
 /// This is the "swiss army knife" method for materializing sparse
 /// tensors into the computation. The types of the `ptr` argument and
-/// the result depend on the action, as explained in the following table
-/// (where "STS" means a sparse-tensor-storage object, "COO" means
-/// a coordinate-scheme object, and "Iterator" means an iterator object).
+/// the result depend on the action, as explained in the following table,
+/// where "STS" means a sparse-tensor-storage object and "COO" means
+/// a coordinate-scheme object.
 ///
 /// Action:         `ptr`:      Returns:
 /// kEmpty          -           STS, empty
 /// kEmptyForward   -           STS, empty, with forwarding COO
 /// kFromCOO        COO         STS, copied from the COO source
 /// kSparseToSparse STS         STS, copied from the STS source
 /// kToCOO          STS         COO, copied from the STS source
-/// kToIterator     STS         Iterator (@getNext/@delSparseTensorIterator)
 /// kPack           buffers     STS, from level buffers
+/// kSortCOOInPlace STS         STS, sorted in place
 MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT
     StridedMemRefType<index_type, 1> *dimSizesRef,
     StridedMemRefType<index_type, 1> *lvlSizesRef,
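
As a reader's aid for the table above, the sketch below spells out which actions consume a non-null `ptr` argument. The helper name `actionNeedsPtr` is hypothetical (it is not part of the runtime); the `Action` enumerators are the ones left in Enums.h after this commit.

#include "mlir/Dialect/SparseTensor/IR/Enums.h"

using mlir::sparse_tensor::Action;

// Hypothetical helper: which actions require a non-null `ptr` argument.
static bool actionNeedsPtr(Action action) {
  switch (action) {
  case Action::kEmpty:
  case Action::kEmptyForward:
    return false; // a fresh STS is materialized; `ptr` is unused
  case Action::kFromCOO:        // `ptr` is a COO source
  case Action::kSparseToSparse: // `ptr` is an STS source
  case Action::kToCOO:          // `ptr` is an STS source
  case Action::kPack:           // `ptr` points at level buffers
  case Action::kSortCOOInPlace: // `ptr` is the STS to sort in place
    return true;
  }
  return true; // unreachable once all enumerators are covered
}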
@@ -90,14 +90,6 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
 MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT)
 #undef DECL_FORWARDINGINSERT
 
-/// Coordinate-scheme method for getting the next element while iterating.
-#define DECL_GETNEXT(VNAME, V)                                                 \
-  MLIR_CRUNNERUTILS_EXPORT bool _mlir_ciface_getNext##VNAME(                   \
-      void *iter, StridedMemRefType<index_type, 1> *cref,                      \
-      StridedMemRefType<V, 0> *vref);
-MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETNEXT)
-#undef DECL_GETNEXT
-
 /// Tensor-storage method to insert elements in lexicographical
 /// level-coordinate order.
 #define DECL_LEXINSERT(VNAME, V)                                               \
@@ -201,12 +193,6 @@ MLIR_CRUNNERUTILS_EXPORT void delSparseTensor(void *tensor);
 MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
 #undef DECL_DELCOO
 
-/// Releases the memory for an iterator object.
-#define DECL_DELITER(VNAME, V)                                                 \
-  MLIR_CRUNNERUTILS_EXPORT void delSparseTensorIterator##VNAME(void *iter);
-MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELITER)
-#undef DECL_DELITER
-
 /// Helper function to read a sparse tensor filename from the environment,
 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
 MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
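
The two macro-generated declarations removed above were the caller-facing surface of the iterator: a per-value-type `_mlir_ciface_getNext<VNAME>` that advances it and a `delSparseTensorIterator<VNAME>` that frees it. Below is a minimal sketch of that now-deleted caller-side protocol for the F64 instantiation; it only compiles against the pre-commit header. The helper name `drainAndDelete`, the rank handling, and the assumption that `iter` came from `Action::kToIterator` are illustrative, and `index_type` is the runtime's `uint64_t` alias.

#include <cstdint>
#include <vector>

#include "mlir/ExecutionEngine/CRunnerUtils.h"        // StridedMemRefType
#include "mlir/ExecutionEngine/SparseTensorRuntime.h" // pre-commit declarations

// Hypothetical helper: drain an iterator handle produced by the (removed)
// Action::kToIterator path, then release it.
static void drainAndDelete(void *iter, uint64_t lvlRank) {
  std::vector<uint64_t> coords(lvlRank); // index_type == uint64_t
  StridedMemRefType<uint64_t, 1> cref;
  cref.basePtr = cref.data = coords.data();
  cref.offset = 0;
  cref.sizes[0] = static_cast<int64_t>(lvlRank);
  cref.strides[0] = 1;

  double value = 0.0;
  StridedMemRefType<double, 0> vref; // rank-0 memref: pointer + offset only
  vref.basePtr = vref.data = &value;
  vref.offset = 0;

  // getNext copies the next element's coordinates and value, and returns
  // false once the underlying COO is exhausted.
  while (_mlir_ciface_getNextF64(iter, &cref, &vref)) {
    // ... consume coords[0 .. lvlRank) and value ...
  }
  // The iterator owns its COO, so deleting it also frees that storage.
  delSparseTensorIteratorF64(iter);
}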

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

Lines changed: 3 additions & 106 deletions
@@ -63,71 +63,18 @@ using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
 //
-// Implementation details for public functions, which don't have a good
-// place to live in the C++ library this file is wrapping.
+// Utilities for manipulating `StridedMemRefType`.
 //
 //===----------------------------------------------------------------------===//
 
 namespace {
 
-/// Wrapper class to avoid memory leakage issues. The `SparseTensorCOO<V>`
-/// class provides a standard C++ iterator interface, where the iterator
-/// is implemented as per `std::vector`'s iterator. However, for MLIR's
-/// usage we need to have an iterator which also holds onto the underlying
-/// `SparseTensorCOO<V>` so that it can be freed whenever the iterator
-/// is freed.
-//
-// We name this `SparseTensorIterator` rather than `SparseTensorCOOIterator`
-// for future-proofing, since the use of `SparseTensorCOO` is an
-// implementation detail that we eventually want to change (e.g., to
-// use `SparseTensorEnumerator` directly, rather than constructing the
-// intermediate `SparseTensorCOO` at all).
-template <typename V>
-class SparseTensorIterator final {
-public:
-  /// This ctor requires `coo` to be a non-null pointer to a dynamically
-  /// allocated object, and takes ownership of that object. Therefore,
-  /// callers must not free the underlying COO object, since the iterator's
-  /// dtor will do so.
-  explicit SparseTensorIterator(const SparseTensorCOO<V> *coo)
-      : coo(coo), it(coo->begin()), end(coo->end()) {}
-
-  ~SparseTensorIterator() { delete coo; }
-
-  // Disable copy-ctor and copy-assignment, to prevent double-free.
-  SparseTensorIterator(const SparseTensorIterator<V> &) = delete;
-  SparseTensorIterator<V> &operator=(const SparseTensorIterator<V> &) = delete;
-
-  /// Gets the next element. If there are no remaining elements, then
-  /// returns nullptr.
-  const Element<V> *getNext() { return it < end ? &*it++ : nullptr; }
-
-private:
-  const SparseTensorCOO<V> *const coo; // Owning pointer.
-  typename SparseTensorCOO<V>::const_iterator it;
-  const typename SparseTensorCOO<V>::const_iterator end;
-};
-
-//===----------------------------------------------------------------------===//
-//
-// Utilities for manipulating `StridedMemRefType`.
-//
-//===----------------------------------------------------------------------===//
-
-// We shouldn't need to use `detail::safelyEQ` here since the `1` is a literal.
 #define ASSERT_NO_STRIDE(MEMREF)                                               \
   do {                                                                         \
     assert((MEMREF) && "Memref is nullptr");                                   \
     assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride");    \
   } while (false)
 
-// All our functions use `uint64_t` for ranks, but `StridedMemRefType::sizes`
-// uses `int64_t` on some platforms. So we explicitly cast this lookup to
-// ensure we get a consistent type, and we use `checkOverflowCast` rather
-// than `static_cast` just to be extremely sure that the casting can't
-// go awry. (The cast should aways be safe since (1) sizes should never
-// be negative, and (2) the maximum `int64_t` is smaller than the maximum
-// `uint64_t`. But it's better to be safe than sorry.)
 #define MEMREF_GET_USIZE(MEMREF)                                               \
   detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])
 
@@ -137,22 +84,13 @@ class SparseTensorIterator final {
 
 #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
 
-/// Initializes the memref with the provided size and data pointer.  This
+/// Initializes the memref with the provided size and data pointer. This
 /// is designed for functions which want to "return" a memref that aliases
 /// into memory owned by some other object (e.g., `SparseTensorStorage`),
 /// without doing any actual copying. (The "return" is in scarequotes
 /// because the `_mlir_ciface_` calling convention migrates any returned
 /// memrefs into an out-parameter passed before all the other function
 /// parameters.)
-///
-/// We make this a function rather than a macro mainly for type safety
-/// reasons. This function does not modify the data pointer, but it
-/// cannot be marked `const` because it is stored into the (necessarily)
-/// non-`const` memref. This function is templated over the `DataSizeT`
-/// to work around signedness warnings due to many data types having
-/// varying signedness across different platforms. The templating allows
-/// this function to ensure that it does the right thing and never
-/// introduces errors due to implicit conversions.
 template <typename DataSizeT, typename T>
 static inline void aliasIntoMemref(DataSizeT size, T *data,
                                    StridedMemRefType<T, 1> &ref) {
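
The hunk above keeps only the signature of `aliasIntoMemref`; as a rough illustration of what such aliasing amounts to (a sketch, not the body from this file), a rank-1 memref descriptor can be pointed at externally owned memory like this, copying nothing:

#include <cstdint>
#include "mlir/ExecutionEngine/CRunnerUtils.h" // StridedMemRefType

// Sketch only: the name aliasIntoMemrefSketch marks this as illustrative.
template <typename T>
static void aliasIntoMemrefSketch(uint64_t size, T *data,
                                  StridedMemRefType<T, 1> &ref) {
  ref.basePtr = ref.data = data; // the memref aliases `data`, no copy is made
  ref.offset = 0;
  ref.sizes[0] = static_cast<int64_t>(size);
  ref.strides[0] = 1;            // dense, unit-stride view
}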
@@ -200,20 +138,11 @@ extern "C" {
         dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,      \
         dimRank, tensor);                                                      \
     }                                                                          \
-    case Action::kFuture: {                                                    \
-      break;                                                                   \
-    }                                                                          \
     case Action::kToCOO: {                                                     \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
       return tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim);       \
     }                                                                          \
-    case Action::kToIterator: {                                                \
-      assert(ptr && "Received nullptr for SparseTensorStorage object");        \
-      auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
-      auto *coo = tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim);  \
-      return new SparseTensorIterator<V>(coo);                                 \
-    }                                                                          \
     case Action::kPack: {                                                      \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
@@ -372,7 +301,6 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
 
   // Unsupported case (add above if needed).
-  // TODO: better pretty-printing of enum values!
   MLIR_SPARSETENSOR_FATAL(
       "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
       static_cast<int>(posTp), static_cast<int>(crdTp),
@@ -428,29 +356,6 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
 #undef IMPL_FORWARDINGINSERT
 
-// NOTE: the `cref` argument uses the same coordinate-space as the `iter`
-// (which can be either dim- or lvl-coords, depending on context).
-#define IMPL_GETNEXT(VNAME, V)                                                 \
-  bool _mlir_ciface_getNext##VNAME(void *iter,                                 \
-                                   StridedMemRefType<index_type, 1> *cref,     \
-                                   StridedMemRefType<V, 0> *vref) {            \
-    assert(iter &&vref);                                                       \
-    ASSERT_NO_STRIDE(cref);                                                    \
-    index_type *coords = MEMREF_GET_PAYLOAD(cref);                             \
-    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
-    const uint64_t rank = MEMREF_GET_USIZE(cref);                              \
-    const Element<V> *elem =                                                   \
-        static_cast<SparseTensorIterator<V> *>(iter)->getNext();               \
-    if (elem == nullptr)                                                       \
-      return false;                                                            \
-    for (uint64_t d = 0; d < rank; d++)                                        \
-      coords[d] = elem->coords[d];                                             \
-    *value = elem->value;                                                      \
-    return true;                                                               \
-  }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
-#undef IMPL_GETNEXT
-
 #define IMPL_LEXINSERT(VNAME, V)                                               \
   void _mlir_ciface_lexInsert##VNAME(                                          \
       void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
@@ -636,7 +541,6 @@ void *_mlir_ciface_newSparseTensorFromReader(
       CASE_SECSAME(kU64, kC32, uint64_t, complex32);
 
   // Unsupported case (add above if needed).
-  // TODO: better pretty-printing of enum values!
   MLIR_SPARSETENSOR_FATAL(
       "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
       static_cast<int>(posTp), static_cast<int>(crdTp),
@@ -701,7 +605,7 @@ void endLexInsert(void *tensor) {
 
 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
-    assert(coo && "Got nullptr for COO object");                               \
+    assert(coo);                                                               \
     auto &coo_ = *static_cast<SparseTensorCOO<V> *>(coo);                      \
     if (sort)                                                                  \
       coo_.sort();                                                             \
@@ -721,13 +625,6 @@ void delSparseTensor(void *tensor) {
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELCOO)
 #undef IMPL_DELCOO
 
-#define IMPL_DELITER(VNAME, V)                                                 \
-  void delSparseTensorIterator##VNAME(void *iter) {                            \
-    delete static_cast<SparseTensorIterator<V> *>(iter);                       \
-  }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELITER)
-#undef IMPL_DELITER
-
 char *getTensorFilename(index_type id) {
   constexpr size_t BUF_SIZE = 80;
   char var[BUF_SIZE];
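
The helper kept unchanged at the end of the last hunk, `getTensorFilename`, resolves a tensor filename from the environment using the ${TENSOR0}, ${TENSOR1}, ... convention documented in the header. A minimal usage sketch follows; the environment value shown is illustrative.

#include <cstdio>
#include "mlir/ExecutionEngine/SparseTensorRuntime.h"

int main() {
  // Assumes the caller exported e.g. TENSOR0=/path/to/tensor.mtx beforehand;
  // getTensorFilename(0) looks up the ${TENSOR0} environment variable.
  char *filename = getTensorFilename(/*id=*/0);
  std::printf("reading sparse tensor from %s\n", filename);
  return 0;
}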
