Commit 27ea470

[mlir][sparse] Add runtime support for reading a COO tensor and writing the data to the given indices and values buffers.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143862
1 parent 00e2098 commit 27ea470

File tree

5 files changed: +178 additions, −28 deletions
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 26 additions & 0 deletions
@@ -92,6 +92,11 @@ enum class PrimaryType : uint32_t {
 };
 
 // This x-macro includes all `V` types.
+// TODO: We currently split out the non-variadic version from the variadic
+// version. Using ##__VA_ARGS__ to avoid the split gives
+// warning: token pasting of ',' and __VA_ARGS__ is a GNU extension
+// [-Wgnu-zero-variadic-macro-arguments]
+// and __VA_OPT__(, ) __VA_ARGS__ requires c++20.
 #define MLIR_SPARSETENSOR_FOREVERY_V(DO) \
   DO(F64, double) \
   DO(F32, float) \
@@ -104,6 +109,27 @@ enum class PrimaryType : uint32_t {
   DO(C64, complex64) \
   DO(C32, complex32)
 
+// This x-macro includes all `V` types and supports variadic arguments.
+#define MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, ...) \
+  DO(F64, double, __VA_ARGS__) \
+  DO(F32, float, __VA_ARGS__) \
+  DO(F16, f16, __VA_ARGS__) \
+  DO(BF16, bf16, __VA_ARGS__) \
+  DO(I64, int64_t, __VA_ARGS__) \
+  DO(I32, int32_t, __VA_ARGS__) \
+  DO(I16, int16_t, __VA_ARGS__) \
+  DO(I8, int8_t, __VA_ARGS__) \
+  DO(C64, complex64, __VA_ARGS__) \
+  DO(C32, complex32, __VA_ARGS__)
+
+// This x-macro calls its argument on every pair of overhead and `V` types.
+#define MLIR_SPARSETENSOR_FOREVERY_V_O(DO) \
+  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 64, uint64_t) \
+  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 32, uint32_t) \
+  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 16, uint16_t) \
+  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 8, uint8_t) \
+  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 0, index_type)
+
 constexpr bool isFloatingPrimaryType(PrimaryType valTy) {
   return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kBF16;
 }
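For a sense of how the new x-macro is meant to be consumed: a client defines a DO callback that takes the (VNAME, V) pair from the value-type list plus the (CNAME, C) pair forwarded through the variadic arguments, then invokes MLIR_SPARSETENSOR_FOREVERY_V_O once. The sketch below is not part of the patch; PRINT_COMBO and printAllCombinations are made-up names used only to show the expansion.

#include <iostream>

#include "mlir/Dialect/SparseTensor/IR/Enums.h"

// Hypothetical callback: receives (VNAME, V, CNAME, C) for every combination,
// e.g. (F32, float, 0, index_type) prints "0F32: C=index_type, V=float".
#define PRINT_COMBO(VNAME, V, CNAME, C)                                        \
  std::cout << #CNAME #VNAME << ": C=" << #C << ", V=" << #V << "\n";

void printAllCombinations() {
  // Expands PRINT_COMBO once per (overhead, value) pairing: 5 x 10 = 50 times.
  MLIR_SPARSETENSOR_FOREVERY_V_O(PRINT_COMBO)
}
#undef PRINT_COMBO

This is the same pattern the runtime header and implementation below use for their DECL_GETNEXT and IMPL_GETNEXT callbacks.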

mlir/include/mlir/ExecutionEngine/SparseTensor/File.h

Lines changed: 78 additions & 0 deletions
@@ -249,6 +249,14 @@ class SparseTensorReader final {
     return tensor;
   }
 
+  /// Reads the COO tensor from the file, stores the coordinates and values to
+  /// the given buffers, returns a boolean value to indicate whether the COO
+  /// elements are sorted.
+  /// Precondition: the buffers should have enough space to hold the elements.
+  template <typename C, typename V>
+  bool readToBuffers(uint64_t lvlRank, const uint64_t *dim2lvl,
+                     C *lvlCoordinates, V *values);
+
 private:
   /// Attempts to read a line from the file. Is private because there's
   /// no reason for client code to call it.
@@ -287,6 +295,13 @@ class SparseTensorReader final {
   void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
                    SparseTensorCOO<V> *lvlCOO);
 
+  /// The internal implementation of `readToBuffers`. We template over
+  /// `IsPattern` in order to perform LICM without needing to duplicate the
+  /// source code.
+  template <typename C, typename V, bool IsPattern>
+  bool readToBuffersLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
+                         C *lvlCoordinates, V *values);
+
   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
 
@@ -351,6 +366,69 @@ void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
   }
 }
 
+template <typename C, typename V>
+bool SparseTensorReader::readToBuffers(uint64_t lvlRank,
+                                       const uint64_t *dim2lvl,
+                                       C *lvlCoordinates, V *values) {
+  assert(isValid() && "Attempt to readCOO() before readHeader()");
+  const uint64_t dimRank = getRank();
+  assert(lvlRank == dimRank && "Rank mismatch");
+  detail::PermutationRef d2l(dimRank, dim2lvl);
+  // Do some manual LICM, to avoid assertions in the for-loop.
+  bool isSorted =
+      isPattern()
+          ? readToBuffersLoop<C, V, true>(lvlRank, d2l, lvlCoordinates, values)
+          : readToBuffersLoop<C, V, false>(lvlRank, d2l, lvlCoordinates,
+                                           values);
+
+  // Close the file and return isSorted.
+  closeFile();
+  return isSorted;
+}
+
+template <typename C, typename V, bool IsPattern>
+bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
+                                           detail::PermutationRef dim2lvl,
+                                           C *lvlCoordinates, V *values) {
+  const uint64_t dimRank = getRank();
+  const uint64_t nse = getNNZ();
+  std::vector<C> dimCoords(dimRank);
+  // Read the first element with isSorted=false as a way to avoid accessing its
+  // previous element.
+  bool isSorted = false;
+  char *linePtr;
+  // We inline `readCOOElement` here in order to avoid redundant assertions,
+  // since they're guaranteed by the call to `isValid()` and the construction
+  // of `dimCoords` above.
+  auto readElement = [&]() {
+    linePtr = readCOOIndices<C>(dimCoords.data());
+    dim2lvl.pushforward(dimRank, dimCoords.data(), lvlCoordinates);
+    *values = detail::readCOOValue<V, IsPattern>(&linePtr);
+    if (isSorted) {
+      // Note that isSorted was set to false while reading the first element,
+      // to guarantee the safeness of using prevLvlCoords.
+      C *prevLvlCoords = lvlCoordinates - lvlRank;
+      // TODO: define a new CoordsLT which is like ElementLT but doesn't have
+      // the V parameter, and use it here.
+      for (uint64_t l = 0; l < lvlRank; ++l) {
+        if (prevLvlCoords[l] != lvlCoordinates[l]) {
+          if (prevLvlCoords[l] > lvlCoordinates[l])
+            isSorted = false;
+          break;
+        }
+      }
+    }
+    lvlCoordinates += lvlRank;
+    ++values;
+  };
+  readElement();
+  isSorted = true;
+  for (uint64_t n = 1; n < nse; ++n)
+    readElement();
+
+  return isSorted;
+}
+
 /// Writes the sparse tensor to `filename` in extended FROSTT format.
 template <typename V>
 inline void writeExtFROSTT(const SparseTensorCOO<V> &coo,
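A minimal sketch of driving readToBuffers directly from C++, assuming a SparseTensorReader whose header has already been read (which is what the runtime's createSparseTensorReader entry point arranges). readAllToCOO is a hypothetical helper, and the identity dim2lvl mapping is simply the easiest choice; the buffers are sized from getNNZ() and getRank() to satisfy the documented precondition.

#include <cstdint>
#include <vector>

#include "mlir/ExecutionEngine/SparseTensor/File.h"

using mlir::sparse_tensor::SparseTensorReader;

// Reads every element of the file into flat COO buffers: `coords` receives
// rank entries per element (interleaved), `values` one entry per element.
bool readAllToCOO(SparseTensorReader &reader, std::vector<uint64_t> &coords,
                  std::vector<double> &values) {
  const uint64_t rank = reader.getRank();
  const uint64_t nse = reader.getNNZ();
  coords.resize(nse * rank);
  values.resize(nse);
  // Identity dim->lvl mapping: level i is dimension i.
  std::vector<uint64_t> dim2lvl(rank);
  for (uint64_t d = 0; d < rank; ++d)
    dim2lvl[d] = d;
  // Returns true iff the elements appeared in the file in sorted order.
  return reader.readToBuffers<uint64_t, double>(rank, dim2lvl.data(),
                                                coords.data(), values.data());
}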

mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

Lines changed: 11 additions & 0 deletions
@@ -283,6 +283,17 @@ MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
 MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETNEXT)
 #undef DECL_GETNEXT
 
+/// Reads the sparse tensor, stores the coordinates and values to the given
+/// memrefs. Returns a boolean value to indicate whether the COO elements are
+/// sorted.
+#define DECL_GETNEXT(VNAME, V, CNAME, C) \
+  MLIR_CRUNNERUTILS_EXPORT bool \
+      _mlir_ciface_getSparseTensorReaderRead##CNAME##VNAME( \
+          void *p, StridedMemRefType<index_type, 1> *dim2lvlRef, \
+          StridedMemRefType<C, 1> *iref, StridedMemRefType<V, 1> *vref);
+MLIR_SPARSETENSOR_FOREVERY_V_O(DECL_GETNEXT)
+#undef DECL_GETNEXT
+
 using SparseTensorWriter = std::ostream;
 
 /// Creates a SparseTensorWriter for outputing a sparse tensor to a file with
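Spelled out for one combination: with CNAME = 0 (overhead type index_type) and VNAME = F32, the declaration above expands to the symbol that the integration test below binds to. The _mlir_ciface_ prefix is the MLIR C-interface naming convention used for functions declared with the llvm.emit_c_interface attribute.

MLIR_CRUNNERUTILS_EXPORT bool _mlir_ciface_getSparseTensorReaderRead0F32(
    void *p, StridedMemRefType<index_type, 1> *dim2lvlRef,
    StridedMemRefType<index_type, 1> *iref, StridedMemRefType<float, 1> *vref);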

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

Lines changed: 27 additions & 0 deletions
@@ -631,6 +631,33 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
 
+#define IMPL_GETNEXT(VNAME, V, CNAME, C) \
+  bool _mlir_ciface_getSparseTensorReaderRead##CNAME##VNAME( \
+      void *p, StridedMemRefType<index_type, 1> *dim2lvlRef, \
+      StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) { \
+    assert(p); \
+    auto &reader = *static_cast<SparseTensorReader *>(p); \
+    ASSERT_NO_STRIDE(cref); \
+    ASSERT_NO_STRIDE(vref); \
+    ASSERT_NO_STRIDE(dim2lvlRef); \
+    const uint64_t cSize = MEMREF_GET_USIZE(cref); \
+    const uint64_t vSize = MEMREF_GET_USIZE(vref); \
+    const uint64_t lvlRank = reader.getRank(); \
+    assert(vSize * lvlRank <= cSize); \
+    assert(vSize >= reader.getNNZ() && "Not enough space in buffers"); \
+    ASSERT_USIZE_EQ(dim2lvlRef, lvlRank); \
+    (void)cSize; \
+    (void)vSize; \
+    (void)lvlRank; \
+    C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref); \
+    V *values = MEMREF_GET_PAYLOAD(vref); \
+    index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
+    return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvlCoordinates, \
+                                      values); \
+  }
+MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
+#undef IMPL_GETNEXT
+
 void *_mlir_ciface_newSparseTensorFromReader(
     void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
     StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
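The new entry point can also be exercised from C++ by hand-building rank-1 memref descriptors. The sketch below is illustrative only: wrap1d and readMatrixCOO are made-up helpers, it assumes the reader creation/query entry points already declared in SparseTensorRuntime.h (createSparseTensorReader, getSparseTensorReaderRank, getSparseTensorReaderNNZ, delSparseTensorReader), and it assumes index_type is uint64_t.

#include <cstdint>
#include <vector>

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/SparseTensorRuntime.h"

// Wraps a contiguous buffer in a rank-1 memref descriptor (offset 0, stride 1).
template <typename T>
static StridedMemRefType<T, 1> wrap1d(T *data, int64_t size) {
  StridedMemRefType<T, 1> ref;
  ref.basePtr = ref.data = data;
  ref.offset = 0;
  ref.sizes[0] = size;
  ref.strides[0] = 1;
  return ref;
}

// Reads a COO file into an interleaved coordinate buffer and a value buffer.
static bool readMatrixCOO(char *filename, std::vector<uint64_t> &coords,
                          std::vector<float> &values) {
  void *reader = createSparseTensorReader(filename);
  const uint64_t rank = getSparseTensorReaderRank(reader);
  const uint64_t nse = getSparseTensorReaderNNZ(reader);
  coords.resize(nse * rank);
  values.resize(nse);
  std::vector<uint64_t> dim2lvl(rank); // identity dim->lvl mapping
  for (uint64_t d = 0; d < rank; ++d)
    dim2lvl[d] = d;
  auto dref = wrap1d(dim2lvl.data(), static_cast<int64_t>(rank));
  auto cref = wrap1d(coords.data(), static_cast<int64_t>(coords.size()));
  auto vref = wrap1d(values.data(), static_cast<int64_t>(values.size()));
  const bool isSorted =
      _mlir_ciface_getSparseTensorReaderRead0F32(reader, &dref, &cref, &vref);
  delSparseTensorReader(reader);
  return isSorted;
}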

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir

Lines changed: 36 additions & 28 deletions
@@ -43,6 +43,9 @@ module {
   func.func private @getSparseTensorReaderIsSymmetric(!TensorReader) -> (i1)
   func.func private @copySparseTensorReaderDimSizes(!TensorReader,
     memref<?xindex>) -> () attributes { llvm.emit_c_interface }
+  func.func private @getSparseTensorReaderRead0F32(!TensorReader,
+    memref<?xindex>, memref<?xindex>, memref<?xf32>)
+    -> (i1) attributes { llvm.emit_c_interface }
   func.func private @getSparseTensorReaderNextF32(!TensorReader,
     memref<?xindex>, memref<f32>) -> () attributes { llvm.emit_c_interface }
 
@@ -60,6 +63,14 @@ module {
     return
   }
 
+  func.func @dumpi2(%arg0: memref<?xindex, strided<[2], offset: ?>>) {
+    %c0 = arith.constant 0 : index
+    %v = vector.transfer_read %arg0[%c0], %c0 :
+      memref<?xindex, strided<[2], offset: ?>>, vector<17xindex>
+    vector.print %v : vector<17xindex>
+    return
+  }
+
   func.func @dumpf(%arg0: memref<?xf32>) {
     %c0 = arith.constant 0 : index
     %d0 = arith.constant 0.0 : f32
@@ -70,39 +81,31 @@ module {
 
   // Returns the indices and values of the tensor.
   func.func @readTensorFile(%tensor: !TensorReader)
-    -> (memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+    -> (memref<?xindex>, memref<?xf32>, i1) {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
 
     %rank = call @getSparseTensorReaderRank(%tensor) : (!TensorReader) -> index
     %nnz = call @getSparseTensorReaderNNZ(%tensor) : (!TensorReader) -> index
 
     // Assume rank == 2.
-    %x0s = memref.alloc(%nnz) : memref<?xindex>
-    %x1s = memref.alloc(%nnz) : memref<?xindex>
+    %isize = arith.muli %c2, %nnz : index
+    %xs = memref.alloc(%isize) : memref<?xindex>
     %vs = memref.alloc(%nnz) : memref<?xf32>
-    %indices = memref.alloc(%rank) : memref<?xindex>
-    %value = memref.alloca() : memref<f32>
-    scf.for %i = %c0 to %nnz step %c1 {
-      func.call @getSparseTensorReaderNextF32(%tensor, %indices, %value)
-        : (!TensorReader, memref<?xindex>, memref<f32>) -> ()
-      // TODO: can we use memref.subview to avoid the need for the %value
-      // buffer?
-      %v = memref.load %value[] : memref<f32>
-      memref.store %v, %vs[%i] : memref<?xf32>
-      %i0 = memref.load %indices[%c0] : memref<?xindex>
-      memref.store %i0, %x0s[%i] : memref<?xindex>
-      %i1 = memref.load %indices[%c1] : memref<?xindex>
-      memref.store %i1, %x1s[%i] : memref<?xindex>
-    }
-
-    // Release the resource for the indices.
-    memref.dealloc %indices : memref<?xindex>
-    return %x0s, %x1s, %vs : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    %dim2lvl = memref.alloca(%c2) : memref<?xindex>
+    memref.store %c0, %dim2lvl[%c0] : memref<?xindex>
+    memref.store %c1, %dim2lvl[%c1] : memref<?xindex>
+    %isSorted = func.call @getSparseTensorReaderRead0F32(%tensor, %dim2lvl, %xs, %vs)
+      : (!TensorReader, memref<?xindex>, memref<?xindex>, memref<?xf32>) -> (i1)
+    return %xs, %vs, %isSorted : memref<?xindex>, memref<?xf32>, i1
   }
 
   // Reads a COO tensor from the given file name and prints its content.
   func.func @readTensorFileAndDump(%fileName: !Filename) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
    %tensor = call @createSparseTensorReader(%fileName)
      : (!Filename) -> (!TensorReader)
    %rank = call @getSparseTensorReaderRank(%tensor) : (!TensorReader) -> index
@@ -116,18 +119,22 @@ module {
    func.call @copySparseTensorReaderDimSizes(%tensor, %dimSizes)
      : (!TensorReader, memref<?xindex>) -> ()
    call @dumpi(%dimSizes) : (memref<?xindex>) -> ()
-    %x0s, %x1s, %vs = call @readTensorFile(%tensor)
-      : (!TensorReader) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>)
 
-    call @dumpi(%x0s) : (memref<?xindex>) -> ()
-    call @dumpi(%x1s) : (memref<?xindex>) -> ()
+    %xs, %vs, %isSorted = call @readTensorFile(%tensor)
+      : (!TensorReader) -> (memref<?xindex>, memref<?xf32>, i1)
+    %x0s = memref.subview %xs[%c0][%nnz][%c2]
+      : memref<?xindex> to memref<?xindex, strided<[2], offset: ?>>
+    %x1s = memref.subview %xs[%c1][%nnz][%c2]
+      : memref<?xindex> to memref<?xindex, strided<[2], offset: ?>>
+    vector.print %isSorted : i1
+    call @dumpi2(%x0s) : (memref<?xindex, strided<[2], offset: ?>>) -> ()
+    call @dumpi2(%x1s) : (memref<?xindex, strided<[2], offset: ?>>) -> ()
    call @dumpf(%vs) : (memref<?xf32>) -> ()
 
    // Release the resources.
    call @delSparseTensorReader(%tensor) : (!TensorReader) -> ()
    memref.dealloc %dimSizes : memref<?xindex>
-    memref.dealloc %x0s : memref<?xindex>
-    memref.dealloc %x1s : memref<?xindex>
+    memref.dealloc %xs : memref<?xindex>
    memref.dealloc %vs : memref<?xf32>
 
    return
@@ -184,6 +191,7 @@ module {
    // CHECK: 17
    // CHECK: 0
    // CHECK: ( 4, 256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+    // CHECK: 1
    // CHECK: ( 0, 0, 0, 0, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 )
    // CHECK: ( 0, 126, 127, 254, 1, 253, 2, 0, 1, 3, 98, 126, 127, 128, 249, 253, 255 )
    // CHECK: ( -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, 16, -17 )
@@ -215,4 +223,4 @@ module {
 
    return
  }
-}
+}
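A small standalone illustration (not part of the test) of why the test switches to dumpi2: Read0F32 writes coordinates interleaved per element, so for the rank-2 input the buffer holds [i0, j0, i1, j1, ...], and a stride-2 view with offset 0 or 1 recovers one dimension's coordinates, which is exactly what the memref.subview / dumpi2 pair does above. The sample values come from the first three elements in the CHECK lines.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Interleaved rank-2 coordinates as laid out by readToBuffers:
  // elements (0,0), (0,126), (0,127) -> [0, 0, 0, 126, 0, 127].
  std::vector<uint64_t> xs = {0, 0, 0, 126, 0, 127};
  const uint64_t rank = 2, nse = xs.size() / rank;
  // Offset d, stride `rank`: the same access pattern as the %x0s/%x1s subviews.
  for (uint64_t d = 0; d < rank; ++d) {
    for (uint64_t n = 0; n < nse; ++n)
      std::printf("%llu ", static_cast<unsigned long long>(xs[d + n * rank]));
    std::printf("\n");
  }
  return 0;
}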
