
Commit 980e024

[mlir][sparse] minor edits to support lib files (llvm#68137)
1 parent 95b2c6b commit 980e024

File tree

4 files changed: +47 -131 lines changed

mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h

Lines changed: 16 additions & 48 deletions
@@ -23,20 +23,19 @@ namespace mlir {
 namespace sparse_tensor {
 
 /// An element of a sparse tensor in coordinate-scheme representation
-/// (i.e., a pair of coordinates and value).  For example, a rank-1
+/// (i.e., a pair of coordinates and value). For example, a rank-1
 /// vector element would look like
 ///   ({i}, a[i])
 /// and a rank-5 tensor element would look like
 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
 ///
-/// The coordinates are represented as a (non-owning) pointer into
-/// a shared pool of coordinates, rather than being stored directly in
-/// this object. This significantly improves performance because it:
-/// (1) reduces the per-element memory footprint, and (2) centralizes
-/// the memory management for coordinates. The only downside is that
-/// the coordinates themselves cannot be retrieved without knowing the
-/// rank of the tensor to which this element belongs (and that rank is
-/// not stored in this object).
+/// The coordinates are represented as a (non-owning) pointer into a
+/// shared pool of coordinates, rather than being stored directly in this
+/// object. This significantly improves performance because it reduces the
+/// per-element memory footprint and centralizes the memory management for
+/// coordinates. The only downside is that the coordinates themselves cannot
+/// be retrieved without knowing the rank of the tensor to which this element
+/// belongs (and that rank is not stored in this object).
 template <typename V>
 struct Element final {
   Element(const uint64_t *coords, V val) : coords(coords), value(val){};
@@ -48,10 +47,6 @@ struct Element final {
 template <typename V>
 struct ElementLT final {
   ElementLT(uint64_t rank) : rank(rank) {}
-
-  /// Compares two elements a la `operator<`.
-  ///
-  /// Precondition: the elements must both be valid for `rank`.
   bool operator()(const Element<V> &e1, const Element<V> &e2) const {
     for (uint64_t d = 0; d < rank; ++d) {
       if (e1.coords[d] == e2.coords[d])
@@ -60,13 +55,10 @@ struct ElementLT final {
     }
     return false;
   }
-
   const uint64_t rank;
 };
 
-/// The type of callback functions which receive an element.  We avoid
-/// packaging the coordinates and value together as an `Element` object
-/// because this helps keep code somewhat cleaner.
+/// The type of callback functions which receive an element.
 template <typename V>
 using ElementConsumer =
     const std::function<void(const std::vector<uint64_t> &, V)> &;
@@ -89,27 +81,14 @@ class SparseTensorCOO final {
   using size_type = typename vector_type::size_type;
 
   /// Constructs a new coordinate-scheme sparse tensor with the given
-  /// sizes and initial storage capacity.
-  ///
-  /// Asserts:
-  /// * `dimSizes` has nonzero size.
-  /// * the elements of `dimSizes` are nonzero.
+  /// sizes and an optional initial storage capacity.
   explicit SparseTensorCOO(const std::vector<uint64_t> &dimSizes,
                            uint64_t capacity = 0)
       : SparseTensorCOO(dimSizes.size(), dimSizes.data(), capacity) {}
 
-  // TODO: make a class for capturing known-valid sizes (a la PermutationRef),
-  // so that `SparseTensorStorage::toCOO` can avoid redoing these assertions.
-  // Also so that we can enforce the asserts *before* copying into `dimSizes`.
-  //
   /// Constructs a new coordinate-scheme sparse tensor with the given
-  /// sizes and initial storage capacity.
-  ///
-  /// Precondition: `dimSizes` must be valid for `dimRank`.
-  ///
-  /// Asserts:
-  /// * `dimRank` is nonzero.
-  /// * the elements of `dimSizes` are nonzero.
+  /// sizes and an optional initial storage capacity. The size of the
+  /// dimSizes array is determined by dimRank.
   explicit SparseTensorCOO(uint64_t dimRank, const uint64_t *dimSizes,
                            uint64_t capacity = 0)
       : dimSizes(dimSizes, dimSizes + dimRank), isSorted(true) {
@@ -134,16 +113,7 @@ class SparseTensorCOO final {
   /// Returns the `operator<` closure object for the COO's element type.
   ElementLT<V> getElementLT() const { return ElementLT<V>(getRank()); }
 
-  /// Adds an element to the tensor. This method does not check whether
-  /// `dimCoords` is already associated with a value, it adds it regardless.
-  /// Resolving such conflicts is left up to clients of the iterator
-  /// interface.
-  ///
-  /// This method invalidates all iterators.
-  ///
-  /// Asserts:
-  /// * the `dimCoords` is valid for `getRank`.
-  /// * the components of `dimCoords` are valid for `getDimSizes`.
+  /// Adds an element to the tensor. This method invalidates all iterators.
   void add(const std::vector<uint64_t> &dimCoords, V val) {
     const uint64_t *base = coordinates.data();
     const uint64_t size = coordinates.size();
@@ -154,7 +124,7 @@ class SparseTensorCOO final {
              "Coordinate is too large for the dimension");
       coordinates.push_back(dimCoords[d]);
     }
-    // This base only changes if `coordinates` was reallocated.  In which
+    // This base only changes if `coordinates` was reallocated. In which
     // case, we need to correct all previous pointers into the vector.
    // Note that this only happens if we did not set the initial capacity
    // right, and then only for every internal vector reallocation (which
@@ -175,11 +145,9 @@ class SparseTensorCOO final {
   const_iterator begin() const { return elements.cbegin(); }
   const_iterator end() const { return elements.cend(); }
 
-  /// Sorts elements lexicographically by coordinates.  If a coordinate
+  /// Sorts elements lexicographically by coordinates. If a coordinate
   /// is mapped to multiple values, then the relative order of those
-  /// values is unspecified.
-  ///
-  /// This method invalidates all iterators.
+  /// values is unspecified. This method invalidates all iterators.
   void sort() {
     if (isSorted)
       return;
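
For orientation, the trimmed documentation above still describes a small, self-contained interface. A minimal usage sketch of that interface (hypothetical client code, not part of this commit; it relies only on the members visible in this diff, and cooSketch is an invented name):

    #include "mlir/ExecutionEngine/SparseTensor/COO.h"

    #include <cstdint>
    #include <vector>

    using mlir::sparse_tensor::SparseTensorCOO;

    // Build a small 3x4 rank-2 COO tensor, add a few elements, then sort.
    void cooSketch() {
      const std::vector<uint64_t> dimSizes{3, 4};
      SparseTensorCOO<double> coo(dimSizes, /*capacity=*/4);
      coo.add({0, 1}, 1.0);  // appends unconditionally; invalidates iterators
      coo.add({2, 3}, 2.5);
      coo.add({1, 0}, -3.0);
      coo.sort();            // lexicographic by coordinates; invalidates iterators
      for (const auto &elem : coo) {
        // elem.coords is a non-owning pointer into the shared coordinate pool;
        // elem.value is the stored value.
        (void)elem.coords[0];
        (void)elem.value;
      }
    }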

mlir/include/mlir/ExecutionEngine/SparseTensor/File.h

Lines changed: 4 additions & 25 deletions
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements parsing and printing of files in one of the
-// following external formats:
+// This file implements reading and writing files in one of the following
+// external formats:
 //
 // (1) Matrix Market Exchange (MME): *.mtx
 //     https://math.nist.gov/MatrixMarket/formats.html
@@ -197,31 +197,14 @@ class SparseTensorReader final {
 
   /// Allocates a new COO object for `lvlSizes`, initializes it by reading
   /// all the elements from the file and applying `dim2lvl` to their
-  /// dim-coordinates, and then closes the file.
-  ///
-  /// Preconditions:
-  /// * `lvlSizes` must be valid for `lvlRank`.
-  /// * `dim2lvl` must be valid for `getRank()`.
-  /// * `dim2lvl` maps `getDimSizes()`-coordinates to `lvlSizes`-coordinates.
-  /// * the file's actual value type can be read as `V`.
-  ///
-  /// Asserts:
-  /// * `isValid()`
-  /// * `dim2lvl` is a permutation, and therefore also `lvlRank == getRank()`.
-  ///   (This requirement will be lifted once we functionalize `dim2lvl`.)
-  //
-  // NOTE: This method is factored out of `readSparseTensor` primarily to
-  // reduce code bloat (since the bulk of the code doesn't care about the
-  // `<P,I>` type template parameters).  But we leave it public since it's
-  // perfectly reasonable for clients to use.
+  /// dim-coordinates, and then closes the file. Templated on V only.
   template <typename V>
   SparseTensorCOO<V> *readCOO(uint64_t lvlRank, const uint64_t *lvlSizes,
                               const uint64_t *dim2lvl);
 
   /// Allocates a new sparse-tensor storage object with the given encoding,
   /// initializes it by reading all the elements from the file, and then
-  /// closes the file. Preconditions/assertions are as per `readCOO`
-  /// and `SparseTensorStorage::newFromCOO`.
+  /// closes the file. Templated on P, I, and V.
   template <typename P, typename I, typename V>
   SparseTensorStorage<P, I, V> *
   readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes,
@@ -312,10 +295,6 @@ SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
                                                 const uint64_t *lvlSizes,
                                                 const uint64_t *dim2lvl) {
   assert(isValid() && "Attempt to readCOO() before readHeader()");
-  // Construct a `PermutationRef` for the `pushforward` below.
-  // TODO: This specific implementation does not generalize to arbitrary
-  // mappings, but once we functionalize the `dim2lvl` argument we can
-  // simply use that function instead.
   const uint64_t dimRank = getRank();
   assert(lvlRank == dimRank && "Rank mismatch");
   detail::PermutationRef d2l(dimRank, dim2lvl);
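
Since the trimmed comments above no longer spell out the preconditions, here is a minimal sketch of a readCOO call (hypothetical readAsCOO helper, not part of this commit; it assumes readHeader() has already been performed, that the file holds a rank-2 tensor, and it uses the identity dim2lvl permutation that the current implementation asserts):

    #include "mlir/ExecutionEngine/SparseTensor/File.h"

    #include <cstdint>

    using namespace mlir::sparse_tensor;

    // Read an already-opened rank-2 tensor file into a COO object.
    // readCOO asserts isValid(), i.e. that readHeader() was called, and that
    // dim2lvl is a permutation with lvlRank == getRank().
    SparseTensorCOO<float> *readAsCOO(SparseTensorReader &reader,
                                      const uint64_t *lvlSizes) {
      const uint64_t lvlRank = 2;
      const uint64_t dim2lvl[] = {0, 1}; // identity dim-to-lvl mapping
      return reader.readCOO<float>(lvlRank, lvlSizes, dim2lvl);
    }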

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 23 additions & 51 deletions
@@ -25,9 +25,30 @@
 #include "mlir/ExecutionEngine/SparseTensor/COO.h"
 #include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
 
+#define ASSERT_VALID_DIM(d)                                                    \
+  assert(d < getDimRank() && "Dimension is out of bounds");
+#define ASSERT_VALID_LVL(l)                                                    \
+  assert(l < getLvlRank() && "Level is out of bounds");
+#define ASSERT_COMPRESSED_LVL(l)                                               \
+  assert(isCompressedLvl(l) && "Level is not compressed");
+#define ASSERT_COMPRESSED_OR_SINGLETON_LVL(l)                                  \
+  do {                                                                         \
+    const DimLevelType dlt = getLvlType(l);                                    \
+    (void)dlt;                                                                 \
+    assert((isCompressedDLT(dlt) || isSingletonDLT(dlt)) &&                    \
+           "Level is neither compressed nor singleton");                       \
+  } while (false)
+#define ASSERT_DENSE_DLT(dlt) assert(isDenseDLT(dlt) && "Level is not dense");
+
 namespace mlir {
 namespace sparse_tensor {
 
+// Forward references.
+template <typename V>
+class SparseTensorEnumeratorBase;
+template <typename P, typename C, typename V>
+class SparseTensorEnumerator;
+
 namespace detail {
 
 /// Checks whether the `perm` array is a permutation of `[0 .. size)`.
@@ -125,33 +146,6 @@ class PermutationRef final {
 
 } // namespace detail
 
-//===----------------------------------------------------------------------===//
-// This forward decl is sufficient to split `SparseTensorStorageBase` into
-// its own header, but isn't sufficient for `SparseTensorStorage` to join it.
-template <typename V>
-class SparseTensorEnumeratorBase;
-
-// These macros ensure consistent error messages, without risk of incuring
-// an additional method call to do so.
-#define ASSERT_VALID_DIM(d)                                                    \
-  assert(d < getDimRank() && "Dimension is out of bounds");
-#define ASSERT_VALID_LVL(l)                                                    \
-  assert(l < getLvlRank() && "Level is out of bounds");
-#define ASSERT_COMPRESSED_LVL(l)                                               \
-  assert(isCompressedLvl(l) && "Level is not compressed");
-#define ASSERT_COMPRESSED_OR_SINGLETON_LVL(l)                                  \
-  do {                                                                         \
-    const DimLevelType dlt = getLvlType(l);                                    \
-    (void)dlt;                                                                 \
-    assert((isCompressedDLT(dlt) || isSingletonDLT(dlt)) &&                    \
-           "Level is neither compressed nor singleton");                       \
-  } while (false)
-// Because the `SparseTensorStorageBase` ctor uses `MLIR_SPARSETENSOR_FATAL`
-// (rather than `assert`) when validating level-types, all the uses of
-// `ASSERT_DENSE_DLT` are technically unnecessary. However, they are
-// retained for the sake of future-proofing.
-#define ASSERT_DENSE_DLT(dlt) assert(isDenseDLT(dlt) && "Level is not dense");
-
 /// Abstract base class for `SparseTensorStorage<P,C,V>`. This class
 /// takes responsibility for all the `<P,C,V>`-independent aspects
 /// of the tensor (e.g., shape, sparsity, permutation). In addition,
@@ -185,23 +179,9 @@ class SparseTensorEnumeratorBase;
 /// specified. Thus, dynamic cardinalities always have an "immutable but
 /// unknown" value; so the term "dynamic" should not be taken to indicate
 /// run-time mutability.
-//
-// TODO: we'd like to factor out a class akin to `PermutationRef` for
-// capturing known-valid sizes to avoid redundant validity assertions.
-// But calling that class "SizesRef" would be a terrible name (and
-// "ValidSizesRef" isn't much better). Whereas, calling it "ShapeRef"
-// would be a lot nicer, but then that conflicts with the terminology
-// introduced above. So we need to come up with some new terminology
-// for distinguishing things, which allows a reasonable class name too.
 class SparseTensorStorageBase {
 protected:
-  // Since this class is virtual, we must disallow public copying in
-  // order to avoid "slicing". Since this class has data members,
-  // that means making copying protected.
-  // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
-  // Copy-assignment would be implicitly deleted (because our fields
-  // are const), so we explicitly delete it for clarity.
  SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
 
 public:
@@ -313,10 +293,8 @@ class SparseTensorStorageBase {
   MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
 #undef DECL_GETVALUES
 
-  /// Element-wise insertion in lexicographic coordinate order.  The first
+  /// Element-wise insertion in lexicographic coordinate order. The first
   /// argument is the level-coordinates for the value being inserted.
-  // TODO: For better safety, this should take a parameter for the
-  // length of `lvlCoords` and check that against `getLvlRank()`.
 #define DECL_LEXINSERT(VNAME, V) virtual void lexInsert(const uint64_t *, V);
   MLIR_SPARSETENSOR_FOREVERY_V(DECL_LEXINSERT)
 #undef DECL_LEXINSERT
@@ -348,12 +326,6 @@ class SparseTensorStorageBase {
   const std::vector<uint64_t> lvl2dim;
 };
 
-//===----------------------------------------------------------------------===//
-// This forward decl is necessary for defining `SparseTensorStorage`,
-// but isn't sufficient for splitting it off.
-template <typename P, typename C, typename V>
-class SparseTensorEnumerator;
-
 /// A memory-resident sparse tensor using a storage scheme based on
 /// per-level sparse/dense annotations. This data structure provides
 /// a bufferized form of a sparse tensor type. In contrast to generating
@@ -612,7 +584,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
         [&coo](const auto &trgCoords, V val) { coo->add(trgCoords, val); });
     // TODO: This assertion assumes there are no stored zeros,
     // or if there are then that we don't filter them out.
-    // Cf., <https://github.com/llvm/llvm-project/issues/54179>
+    // <https://github.com/llvm/llvm-project/issues/54179>
    assert(coo->getElements().size() == values.size());
    return coo;
  }
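
One detail worth a note now that the ASSERT_* macros sit at the top of the header: ASSERT_COMPRESSED_OR_SINGLETON_LVL wraps its body in do { ... } while (false) so the multi-statement expansion behaves as a single statement and composes safely with if/else. A generic illustration of the idiom (hypothetical CHECK_POSITIVE macro, not code from this commit):

    #include <cassert>

    // The do/while(false) wrapper makes the expansion a single statement, so
    // the macro can be used wherever exactly one statement is expected; the
    // (void) cast keeps NDEBUG builds free of unused-variable warnings,
    // mirroring the (void)dlt in the header above.
    #define CHECK_POSITIVE(x)                                                  \
      do {                                                                     \
        const auto value = (x);                                                \
        (void)value;                                                           \
        assert(value > 0 && "Value is not positive");                          \
      } while (false)

    void useCheck(int n) {
      if (n != 0)
        CHECK_POSITIVE(n); // expands to one statement; the trailing ';' is safe
      else
        return;
    }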

mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

Lines changed: 4 additions & 7 deletions
@@ -6,9 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This header file provides the enums and functions which comprise the
-// public API of the `ExecutionEngine/SparseTensorRuntime.cpp` runtime
-// support library for the SparseTensor dialect.
+// This header file provides the functions which comprise the public API of the
+// sparse tensor runtime support library for the SparseTensor dialect.
 //
 //===----------------------------------------------------------------------===//
 
@@ -153,8 +152,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETNEXT)
 #undef DECL_GETNEXT
 
 /// Reads the sparse tensor, stores the coordinates and values to the given
-/// memrefs. Returns a boolean value to indicate whether the COO elements are
-/// sorted.
+/// memrefs. Returns a boolean to indicate whether the COO elements are sorted.
 #define DECL_GETNEXT(VNAME, V, CNAME, C)                                       \
   MLIR_CRUNNERUTILS_EXPORT bool                                                \
       _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME(           \
@@ -240,8 +238,7 @@ MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderNSE(void *p);
 MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
                                                                  index_type d);
 
-/// Releases the SparseTensorReader. This also closes the file associated with
-/// the reader.
+/// Releases the SparseTensorReader and closes the associated file.
 MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
 
 /// Creates a SparseTensorWriter for outputting a sparse tensor to a file
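
As a companion to these declarations, here is a short sketch of how a client might query and release a reader through this C API (hypothetical queryReader helper, not part of this commit; obtaining the opaque handle goes through creation functions not shown in this hunk):

    #include "mlir/ExecutionEngine/SparseTensorRuntime.h"

    #include <cstdio>

    // Query basic metadata from an existing SparseTensorReader handle, then
    // release it. `p` is the opaque handle produced by the reader creation
    // API (not shown in this hunk).
    void queryReader(void *p) {
      const index_type nse = getSparseTensorReaderNSE(p);   // stored entries
      const index_type dim0 = getSparseTensorReaderDimSize(p, 0);
      std::printf("nse=%llu, dim0=%llu\n", (unsigned long long)nse,
                  (unsigned long long)dim0);
      delSparseTensorReader(p); // also closes the associated file
    }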
